diff --git a/.gitattributes b/.gitattributes index 0392bc358d3275b0e08aaee18a2cc53d2567d837..209f74f0e890874a6414498146fb10b5a454e9b0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -46,3 +46,4 @@ build/torch211-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter= build/torch211-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch211-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_C.py b/build/torch210-cxx11-cu128-x86_64-linux/_C.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2fd6df85149f4ecf481d67fcd12b52a929a7a4 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_C.py @@ -0,0 +1,194 @@ +import torch + +from ._ops import ops + + +def set_num_sms(num_sms: int): + ops.set_num_sms(num_sms) + + +def get_num_sms() -> int: + return ops.get_num_sms() + + +def set_tc_util(tc_util: int): + ops.set_tc_util(tc_util) + + +def get_tc_util() -> int: + return ops.get_tc_util() + + +def set_ignore_compile_dims(value: bool): + ops.set_ignore_compile_dims(value) + + +def set_block_size_multiple_of(value): + if isinstance(value, tuple): + block_m, block_n = value + else: + block_m = block_n = value + ops.set_block_size_multiple_of(block_m, block_n) + + +def set_pdl(enable_pdl: bool): + ops.set_pdl(enable_pdl) + + +def get_pdl() -> bool: + return ops.get_pdl() + + +def set_mk_alignment_for_contiguous_layout(value: int): + ops.set_mk_alignment_for_contiguous_layout(value) + + +def get_mk_alignment_for_contiguous_layout() -> int: + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int: + return ops.get_theoretical_mk_alignment_for_contiguous_layout( + 0 if expected_m is None else expected_m, + expected_m is not None, + ) + + +def get_tma_aligned_size(mn: int, element_size: int) -> int: + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, gran_k +): + ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks_int, gran_k + ) + + +def transform_sf_into_required_layout( + sf, + mn, + k, + recipe, + num_groups=None, + is_sfa=None, + disable_ue8m0_cast=False, +): + if len(recipe) == 3: + r0, r1, r2 = recipe + recipe_len = 3 + elif len(recipe) == 2: + r0, r1 = recipe + r2 = 0 + recipe_len = 2 + else: + raise ValueError("recipe must have length 2 or 3") + + return ops.transform_sf_into_required_layout( + sf, + mn, + k, + r0, + r1, + r2, + recipe_len, + 0 if num_groups is None else num_groups, + num_groups is not None, + False if is_sfa is None else is_sfa, + is_sfa is not None, + disable_ue8m0_cast, + ) + + +def get_token_alignment_for_mega_moe() -> int: + return ops.get_token_alignment_for_mega_moe() + + +def get_symm_buffer_size_for_mega_moe( + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch=True, + activation="swiglu", +): + num_bytes = ops.get_symm_buffer_size_for_mega_moe( + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch, + activation, + ) + + def slice_input_buffers(buffer): + return tuple( + ops.get_symm_buffer_views_for_mega_moe( + buffer, + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch, + activation, + ) + ) + + return num_bytes, slice_input_buffers + + +def fp8_fp4_mega_moe( + y, + l1_weights, + l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer, + sym_buffer_ptrs, + rank_idx, + num_max_tokens_per_rank, + num_experts, + num_topk, + recipe, + activation, + activation_clamp, + fast_math, +): + l1_weights_data, l1_weights_sf = l1_weights + l2_weights_data, l2_weights_sf = l2_weights + r0, r1, r2 = recipe + ops.fp8_fp4_mega_moe( + y, + l1_weights_data, + l1_weights_sf, + l2_weights_data, + l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer, + sym_buffer_ptrs, + rank_idx, + num_max_tokens_per_rank, + num_experts, + num_topk, + r0, + r1, + r2, + activation, + activation_clamp, + fast_math, + ) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py index 8f0a7f80daf98c3979512b6fb75258a0f4cefdc5..8c4fe1c51ce5c419fc1b9db3b9f7e3ca03258c28 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -1,12 +1,18 @@ import os import subprocess +import sysconfig import torch +# Avoid holding a CUDA tensor in DeepGEMM's process-lifetime runtime singleton. +# In packaged/lazy-loaded use, that can outlive PyTorch's CUDA teardown and crash +# during interpreter shutdown. +os.environ.setdefault("DG_USE_TEMP_CUBLASLT_WORKSPACE", "1") + # Import the compiled extension -from ._ops import ops, add_op_namespace_prefix +from ._ops import ops as _ops, add_op_namespace_prefix from . import utils -__version__ = "2.3.0" +__version__ = "2.5.0" # ── Register fake tensor implementations for torch.compile ────────────────── @@ -32,6 +38,7 @@ for _op in [ "m_grouped_bf16_gemm_nn_contiguous", "m_grouped_bf16_gemm_nt_masked", "fp8_gemm_nt_skip_head_mid", + "fp8_fp4_mega_moe", ]: @torch.library.register_fake(add_op_namespace_prefix(_op)) @@ -58,10 +65,41 @@ def get_tc_util() -> int: return ops.get_tc_util() +def set_ignore_compile_dims(value: bool): + ops.set_ignore_compile_dims(value) + + +def set_block_size_multiple_of(value): + if isinstance(value, tuple): + block_m, block_n = value + else: + block_m = block_n = value + ops.set_block_size_multiple_of(block_m, block_n) + + +def set_pdl(enable_pdl: bool): + ops.set_pdl(enable_pdl) + + +def get_pdl() -> bool: + return ops.get_pdl() + + +def set_mk_alignment_for_contiguous_layout(alignment: int): + ops.set_mk_alignment_for_contiguous_layout(alignment) + + def get_mk_alignment_for_contiguous_layout() -> int: return ops.get_mk_alignment_for_contiguous_layout() +def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int: + return ops.get_theoretical_mk_alignment_for_contiguous_layout( + 0 if expected_m is None else expected_m, + expected_m is not None, + ) + + # Layout utilities @@ -77,10 +115,12 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) -def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, gran_k +): ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( - sf, ks_tensor, ks_int + sf, ks_tensor, ks_int, gran_k ) @@ -88,16 +128,20 @@ def transform_sf_into_required_layout( sf, mn, k, - recipe=None, - recipe_ab=None, + recipe, num_groups=None, - is_sfa=False, + is_sfa=None, disable_ue8m0_cast=False, ): - has_recipe = recipe is not None - r0, r1, r2 = recipe if has_recipe else (0, 0, 0) - has_recipe_ab = recipe_ab is not None - rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0) + if len(recipe) == 3: + r0, r1, r2 = recipe + recipe_len = 3 + elif len(recipe) == 2: + r0, r1 = recipe + r2 = 0 + recipe_len = 2 + else: + raise ValueError("recipe must have length 2 or 3") has_ng = num_groups is not None ng = num_groups if has_ng else 0 return ops.transform_sf_into_required_layout( @@ -107,13 +151,11 @@ def transform_sf_into_required_layout( r0, r1, r2, - has_recipe, - rab0, - rab1, - has_recipe_ab, + recipe_len, ng, has_ng, - is_sfa, + False if is_sfa is None else is_sfa, + is_sfa is not None, disable_ue8m0_cast, ) @@ -593,8 +635,37 @@ def fp8_mqa_logits( ) -def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms): - return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) +def fp8_fp4_mqa_logits( + q, + kv, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits=True, + max_seqlen_k=0, + logits_dtype=torch.float32, +): + if isinstance(q, tuple): + q_data, q_sf = q + else: + q_data, q_sf = q, None + kv_data, kv_sf = kv + return ops.fp8_fp4_mqa_logits( + q_data, + q_sf, + kv_data, + kv_sf, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits, + max_seqlen_k, + logits_dtype, + ) + + +def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None): + return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices) def fp8_paged_mqa_logits( @@ -606,6 +677,7 @@ def fp8_paged_mqa_logits( schedule_meta, max_context_len, clean_logits=False, + indices=None, ): return ops.fp8_paged_mqa_logits( q, @@ -616,6 +688,38 @@ def fp8_paged_mqa_logits( schedule_meta, max_context_len, clean_logits, + indices, + ) + + +def fp8_fp4_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits=False, + logits_dtype=torch.float32, + indices=None, +): + if isinstance(q, tuple): + q_data, q_sf = q + else: + q_data, q_sf = q, None + return ops.fp8_fp4_paged_mqa_logits( + q_data, + q_sf, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + logits_dtype, + indices, ) @@ -642,6 +746,14 @@ def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns) +from .mega import ( + SymmBuffer, + get_symm_buffer_for_mega_moe, + transform_weights_for_mega_moe, + fp8_fp4_mega_moe, +) + + # Initialize the C++ runtime @@ -683,6 +795,14 @@ if "DG_CUTLASS_INCLUDE" not in os.environ: _include, # legacy layout: include/cutlass os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout ] + for _site_packages in { + sysconfig.get_paths().get("purelib"), + sysconfig.get_paths().get("platlib"), + }: + if _site_packages: + _cutlass_include_candidates.append( + os.path.join(_site_packages, "cutlass_library", "source", "include") + ) for _cutlass_include in _cutlass_include_candidates: if os.path.isdir(os.path.join(_cutlass_include, "cutlass")): os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include @@ -703,8 +823,21 @@ def _ensure_initialized(): global _initialized if _initialized: return + _ops.init(_lib_root, _find_cuda_home()) _initialized = True - ops.init(_lib_root, _find_cuda_home()) + + +class _InitializedOps: + def __init__(self, raw_ops): + self._raw_ops = raw_ops + + def __getattr__(self, name): + if name != "init": + _ensure_initialized() + return getattr(self._raw_ops, name) + + +ops = _InitializedOps(_ops) # Try to initialize eagerly, but don't fail if CUDA is not found diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..7917e701aadb84de4a1e76bbe20834d53b7039ec --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2bff23699a1ab0aa2a92bab110612828e10cd623f2f626002ca4a1eba38668e +size 3381200 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py index 65e09b4e92d96545922fbce68acd103c33cd3845..d017d96b9d37776819ba7ab2e5d291158427f1a8 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _deep_gemm_cuda_8546a43 -ops = torch.ops._deep_gemm_cuda_8546a43 +from . import _deep_gemm_cuda_388adb9 +ops = torch.ops._deep_gemm_cuda_388adb9 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_deep_gemm_cuda_8546a43::{op_name}" + return f"_deep_gemm_cuda_388adb9::{op_name}" diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/comm/barrier.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/comm/barrier.cuh new file mode 100644 index 0000000000000000000000000000000000000000..eb9858d8010db9088ae09ead48e6222a40f91075 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/comm/barrier.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include +#include + +namespace deep_gemm::comm { + +CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { + // Perform cluster_sync with `barrier.cluster.arrive.relaxed` + // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} + +template +CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope) { + // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()` + static constexpr uint32_t kFinishSumTag = 0x80000000u; + sync_scope(); + if (thread_idx == 0) { + const auto count_ptr = workspace.get_grid_sync_count_ptr(); + const auto old_value = ptx::atomic_add_rel( + count_ptr, sm_idx == 0 ? (kFinishSumTag - (kNumSMs - 1)) : 1); + uint32_t new_value; + do { + new_value = ptx::ld_acq(count_ptr); + } while (((new_value ^ old_value) & kFinishSumTag) == 0); + } + sync_scope(); +} + +template +CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace, + const layout::SymBuffer& sym_buffer, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope, + const bool& sync_prologue = true, + const bool& sync_epilogue = true) { + DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads"); + + // Grid sync before NVLink signaling + if (sync_prologue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); + + // NVLink cross-rank barrier, only SM 0 participates + if (sm_idx == 0) { + auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr(); + const auto status = (*counter_ptr) & 3; + const auto signal_phase = status & 1, signal_sign = status >> 1; + auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase); + + // Send signals to remote ranks + if (thread_idx < kNumRanks) + ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1); + sync_scope(); + + // Update status and wait arrival (with 30s timeout, at 2 GHz) + constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll; + if (thread_idx == 0) { + ptx::red_add(counter_ptr, 1); + const int target = signal_sign ? 0 : static_cast(kNumRanks); + const auto start_clock = clock64(); + while (ptx::ld_acq_sys(signal_ptr) != target) { + if (clock64() - start_clock >= kNumTimeoutCycles) { + printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n", + sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag); + DG_DEVICE_ASSERT(false and "NVLink barrier timeout"); + } + } + } + } + + // Grid sync after NVLink completion + if (sync_epilogue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); +} + +} // namespace deep_gemm::comm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/compile.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/compile.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e93c43fb77049ef91ca34490657db28bc132783b --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/compile.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include + +#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__) +#define DG_IN_CUDA_COMPILATION +#endif + +#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#else +#define CUTLASS_HOST_DEVICE_NOINLINE +#define CUTLASS_DEVICE_NOINLINE +#endif diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/cute_tie.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/cute_tie.cuh index cd2aace7a8b8dd642f4c149bfc974c3d21e5f5b5..a3a8b62a2823835d14fbbfc26dd603680f2c5a02 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/cute_tie.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/cute_tie.cuh @@ -1,5 +1,7 @@ #pragma once +#include + namespace cute { struct ignore_t { diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/exception.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/exception.cuh new file mode 100644 index 0000000000000000000000000000000000000000..78acf74755f9f1293b50198fbd74d96873354bc3 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/exception.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +#ifdef __CLION_IDE__ + +CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_UNIFIED_ASSERT +#ifdef DG_IN_CUDA_COMPILATION +#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond) +#else +#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond) +#endif +#endif diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/math.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/math.cuh new file mode 100644 index 0000000000000000000000000000000000000000..03bee8f91cf10cd39dadebe8dc6cc2334baed65d --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/math.cuh @@ -0,0 +1,153 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::math { + +/// Pointer operations +template +CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) { + return reinterpret_cast(static_cast(ptr) + num_bytes); +} + +/// Math functions +template +CUTLASS_HOST_DEVICE T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE T align(T a, T b) { + return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; +} + +template +CUTLASS_DEVICE void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +#ifdef DG_IN_CUDA_COMPILATION +CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + return __ffma2_rn(a, b, c); +#else + return make_float2( + __fmaf_rn(a.x, b.x, c.x), + __fmaf_rn(a.y, b.y, c.y) + ); +#endif +} + +CUTLASS_HOST_DEVICE float fast_rcp(const float& x) { +#if defined(__CUDA_ARCH__) + float ret; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +#else + return 1.0f / x; +#endif +} + +/// Casting +template +CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +CUTLASS_DEVICE float fast_pow2(const int& x) { + uint32_t bits_x = (x + 127) << 23; + return *reinterpret_cast(&bits_x); +} + +CUTLASS_DEVICE int fast_log2_ceil(float x) { + const auto bits = *reinterpret_cast(&x); + const auto exp = bits >> 23; + const auto man = bits & ((1 << 23) - 1); + return exp - 127 + (man != 0); +} + +template +CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) { + DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0"); + const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0}; + const auto scaled = __fmul2_rn(amax, finfo_factor); + const auto exp_x = fast_log2_ceil(scaled.x); + const auto exp_y = fast_log2_ceil(scaled.y); + sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x); + sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y); +} + +/// Reduction +CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) { + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset); + if (lane_idx >= offset) + value += synced; + } + return value; +} + +// Operation functors +template struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +CUTLASS_DEVICE T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +CUTLASS_DEVICE T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} +#endif + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/tma_copy.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/tma_copy.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2c5bf708d49737b8912c991d856fa9d4ceb5b5d0 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/tma_copy.cuh @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include + +namespace deep_gemm::tma { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE void +copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +} // namespace deep_gemm::tma diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/types.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/types.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e07df0af8a95a2ae0c6f32493adaa5ec00c09633 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/types.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/utils.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/utils.cuh index 8fb6c2fc53b6d1eb067d13c113462a9f7de4133a..3a5f7ad668878aced913e859780b39ce2c06d3e8 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/utils.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/utils.cuh @@ -1,167 +1,24 @@ #pragma once -#include -#include #include -#include -#include -#include "cute_tie.cuh" +#include -#ifdef __CLION_IDE__ - -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { - asm volatile("trap;"); -} - -#define printf host_device_printf -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_TRAP_ONLY_DEVICE_ASSERT -#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) \ - asm("trap;"); \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) -#endif - -namespace deep_gemm { +namespace deep_gemm::utils { template struct PatternVisitor { FuncT func; - __device__ __host__ + CUTLASS_HOST_DEVICE explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} - __device__ __host__ - auto operator [](const uint32_t& i) { + CUTLASS_HOST_DEVICE + auto operator [](const uint32_t& i) const { return func(i); } }; -template -__device__ __host__ T ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ T align(T a, T b) { - return ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_align(T a, T b) { - return constexpr_ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_gcd(T a, T b) { - return b == 0 ? a : constexpr_gcd(b, a % b); -} - -template -__forceinline__ __device__ void swap(T& a, T& b) { - T temp = a; - a = b; - b = temp; -} - -__forceinline__ __device__ uint32_t get_sm_idx() { - uint32_t sm_idx; - asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); - return sm_idx; -} - -__forceinline__ __device__ uint32_t get_lane_idx() { - uint32_t lane_id; - asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float2 ld_shared(const float2* ptr) { - float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float4 ld_shared(const float4* ptr) { - float4 ret; - asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { - uint4 ret; - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); -} - -__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { - asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); -} - -template -__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { - auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); - return *reinterpret_cast(&bf16x2); -} - -__device__ __forceinline__ void prefetch_l1(void *ptr) { - asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); -} - template struct Vectorized { static auto zeros() { @@ -180,4 +37,14 @@ struct Vectorized { using vec_t = decltype(zeros()); }; -} // namespace `deep_gemm` +template +CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if constexpr (kNumCols <= 32) return 32; + if constexpr (kNumCols <= 64) return 64; + if constexpr (kNumCols <= 128) return 128; + if constexpr (kNumCols <= 256) return 256; + return 512; +} + +} // namespace deep_gemm::utils diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bf0e460c8f636117969d81e21c00e0d2a2586d78 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M waves + constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M; + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]); + + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = base_m_idx + w * STORE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(base_n_idx + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = tmem_base_addr + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = smem_base_ptr + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared( + smem_ptr, + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]) + ); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +} + +} // namespace deep_gemm::epilogue diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f3f5351e6ac6cb0526bb6d2ca8abd5a99ebe45df --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd_swap_ab(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& effective_m, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows, + // implying STORE_BLOCK_N must be 128. + DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows"); + + // TMA checks + constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumSwizzleAtomRows = 8; + DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M blocks + const auto num_stores = effective_m / STORE_BLOCK_M; + for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) { + uint32_t tmem_addr = tmem_base_addr + + s * STORE_BLOCK_M + // Store stage offset + i * kNumSwizzleAtomRows; // In-block offset + uint32_t values[kNumSwizzleAtomRows]; + + // Warps cooperatively write an atomic block to shared memory + DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes"); + constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32; + uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode; + uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode; + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset; + + if constexpr (cute::is_same_v) { + // NOTES: Swizzling is not required in this case, but used here for consistency with other cases + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + uint32_t col = lane_idx / 4; + + #pragma unroll + for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) { + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes + + (lane_idx % 4) * sizeof(float); + ptx::st_shared(reinterpret_cast(smem_ptr), values[row]); + } + } else { + // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements + // Start from lane index 0 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + // Start from lane index 16 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Destination shared memory address + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + + // Store matrix with transposition + ptx::SM90_U32x4_STSM_T::copy(math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (s == num_stores - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) { + auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM; + uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M; + uint32_t n_idx = epilogue_type_t::apply_index_n(base_n_idx + i * STORE_BLOCK_N_ATOM); + + // Issue 2D or 3D TMA store + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx); + } + } + cute::tma_store_arrive(); + } + __syncwarp(); + } +} + +} // namespace deep_gemm::epilogue diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/transform.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/transform.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0266f4d402ab25878a792fb351b32ce1a04924cb --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/transform.cuh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace deep_gemm::epilogue::transform { + +struct EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and + kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +} // namespace deep_gemm::epilogue::transform diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 0227b3e80061409c4dcf89f3f402ce408751246f..a60e2de8df85457a36145b77f06482d49eed0ed7 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -4,14 +4,18 @@ #include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); - // Configs + // MMA Configs constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; - constexpr uint32_t kNumTMAStoreStages = 2; - DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); - DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); - - // Overwrite shape constants if the compiler gives - shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; - shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; - shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - - // Utils - bool is_leader_cta = cute::block_rank_in_cluster() == 0; - const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - - // 2-CTA MMA + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 16; constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); - constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; - DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M; DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes - constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); @@ -91,41 +86,54 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); // NOTES: Make sure we have enough shared memory for UMMA padding - static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); - DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); - - // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size - // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` - constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA"); // Real tensor memory size and offsets - constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? cute::cluster_sync() : void(); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning - if (warp_idx == 0 and cute::elect_one_sync()) { + if (warp_idx == 0) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_cd); } + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + // D/A/B shared memory - auto smem_cd = PatternVisitor([&](const uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; // Fill the tensor memory pointer @@ -159,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout, } kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; @@ -178,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout, // TMA load warp // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); // Compute offsets // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( shape_n, BLOCK_N, n_block_idx, m_block_idx); // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major @@ -195,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or kMajorA == cute::UMMA::Major::K, "Invalid major"); uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Add 2 CTA offsets if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); } @@ -210,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); // Arrive at full barriers @@ -235,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, // MMA issue warp // NOTES: only the leader CTA will do this // Make instruction descriptor - // TODO: refactor `UMMA_M` calculation - constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); - auto instr_desc = cute::UMMA::make_instr_desc(); + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc() + : cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -262,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // UMMA and empty barrier arrival alias auto umma_arrive = [](const uint64_t* barrier) { @@ -279,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting if (do_tmem_full_arrive) umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); }; + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + // Launch MMAs - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait TMA arrival full_barriers[stage_idx]->wait(phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA - using mma_t = cute::conditional_t; - const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); - const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + using mma_t = cute::conditional_t; + const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k_block_idx > 0 or k > 0, - runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + if (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); } } } + __syncwarp(); // Commit to the mbarrier object // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` @@ -319,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kTensorCoreUtilControl < 100) { // For utilization control umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + __syncwarp(); // Wait for last UMMA to be done tensor_core_full_barrier->wait(tensor_core_phase); tensor_core_phase ^= 1; // Sleep for certain cycles - constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull; constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; - const auto& start_clock = clock64(); + const auto start_clock = clock64(); if (cute::elect_one_sync()) while (clock64() - start_clock < kNumDummyCycles) {} __syncwarp(); @@ -336,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout, } // To safely deconstruct barriers, we need another round of waits - const auto& iter_idx = scheduler.current_iter - 1; + const auto iter_idx = scheduler.current_iter - 1; if (kNumMulticast > 1 and iter_idx >= 0) { - const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { @@ -348,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. // NOTES: we also forbid two CTAs to share the same SM and its tensor memory - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); - - // TMA checks - constexpr uint32_t kNumBankGroupBytes = 16; - constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); - DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); - DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Share store pipeline between blocks uint32_t tma_stage_idx = 0; - auto advance_store_pipeline = [&]() { - tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; - }; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -369,108 +385,47 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Wait UMMA arrival tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM - DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Iterate over M waves - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; - #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { - // Wait shared memory to be released - if (epilogue_warp_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - - // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; - - // Store into shared memory - #pragma unroll - for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); - } - } - - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - __syncwarp(); - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], - n_idx, m_idx, scheduler.current_group_idx); - } else { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); - } - cute::tma_store_arrive(); - } - } + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx< + (not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh index 86303347d9c7a3a93b65a16d6ad4a7b73eb2ad1a..13bb087232772ac1e9d65997f733164ed5827c49 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -5,18 +5,19 @@ #include #include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1) sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); @@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, } // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Fill D/A/B - auto smem_cd = PatternVisitor([&](const uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); // Fill the tensor memory pointer @@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, __syncthreads(); // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx == 0) { // TMA load warp for (uint32_t s = 0; s < num_total_stages; ++ s) { @@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Issue TMAs if (cute::elect_one_sync()) { - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); } // Arrive at full barriers @@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, auto instr_desc = cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, "Invalid MMA instruction shape"); // Wait tensor memory empty barrier arrival - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrival const auto& stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); @@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); - SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); } } @@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. // NOTES: we also forbid two CTAs to share the same SM and its tensor memory if (warp_idx == 2) - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // TMA checks constexpr uint32_t kNumBankGroupBytes = 16; @@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); @@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); } // Synchronize all threads and issue TMA diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b8a99fd04273d48a6b500b6e76f1e938be8858da --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh @@ -0,0 +1,457 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, + const uint32_t logits_stride, + const uint32_t* cu_seq_len_k_start, + const uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + + // Allocate tensor memory + if (warp_idx == kSpecWarpStart + 2) + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + __syncthreads(); + + // Scheduler + const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv); + seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv); + start = cute::min(start, seq_k_start[i]); + end = cute::max(end, seq_k_end[i]); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {start, math::ceil_div(end - start, BLOCK_KV)}; + }; + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + // Enumerate Q blocks + if (cute::elect_one_sync()) { + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx], + kv_start + kv_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize umma desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma( + a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this into `deep_gemm/ptx/tcgen05.cuh` + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline. Without this, UMMA can consume + // kNumQStages Q blocks before math warps release any, causing a + // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q + // -> Math waits full_tmem -> UMMA (already moved on). + empty_q_barriers[q_stage_idx]->arrive(); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = threadIdx.x; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr uint32_t N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[BLOCK_Q][kNumHeads]; + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + // TODO: optimize bank conflicts + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Calculate KV offset in advance + auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + } + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + // TODO: optimize performance + const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast(logits_stride); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; + } else { + logits[q_offset + kv_offset] = result; + } + __syncwarp(); + } + } + + // Release last Q empty + empty_q_barriers[q_stage_idx]->arrive(); + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d9add53425517d936ed201f78d277db775d19507 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh @@ -0,0 +1,510 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 2) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + using Scheduler = sched::PagedMQALogitsScheduler; + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + // Initialize outside valid range to indicate no previous task + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, _, __; + while (scheduler.fetch_next_task(q_atom_idx, _, __)) { + // Issue TMA Q when (q_idx, atom_idx) changes + if (q_atom_idx != last_q_atom_idx) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx); + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + last_q_atom_idx = q_atom_idx; + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage; + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, num_kv; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) { + // Reset block table cache on kv restart + if (q_atom_idx != last_q_atom_idx) + kv_block_idx_ptr = 32; + last_q_atom_idx = q_atom_idx; + + // Coalesced load of block table + if (kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; + } + __syncwarp(); + + // Broadcast KV block indices + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`"); + + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + + // Issue TMA KV + if (cute::elect_one_sync()) { + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i, + 0, 0, kv_block_idx[i]); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + // Wait TMA Q arrivals + uint32_t q_stage_idx, q_phase; + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + } + last_q_atom_idx = q_atom_idx; + + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize UMMA desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this PTX into headers + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[kNextNAtom][kNumHeads]; + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + bool is_paired_atom = false; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release last Q empty + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrivals + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j)); + weights[i][j + 0] = raw.x; + weights[i][j + 1] = raw.y; + weights[i][j + 2] = raw.z; + weights[i][j + 3] = raw.w; + } + } + + // Check if this atom pairs two tokens from the same sequence + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2); + } + } + last_q_atom_idx = q_atom_idx; + + // Calculate KV offset in advance + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + + // Only loop over valid iterations + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride)] = result; + __syncwarp(); + } + + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + }; + + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0bc6a3fe26e61057fbcfcc5f4c63d4faa6e475fe --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh @@ -0,0 +1,514 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // SF configs + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M; + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + // NOTES: Make sure we have enough shared memory for UMMA padding + constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? cute::cluster_sync() : void(); + + // Utils + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4); + const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = reinterpret_cast(smem_b[kNumStages]); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]);; + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + uint32_t sfa_m_idx = m_block_idx * BLOCK_M; + uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>( + shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)); + tma::copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = scheduler.template get_global_idx( + shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx); + tma::copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled() + : cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + + // Launch MMAs + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + #pragma unroll 4 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // Do SF copy at certain stages + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + + // Issue UMMA + using mma_t = cute::conditional_t< + kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto runtime_instr_desc = kSwapAB ? + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id): + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + a_desc.lo = mma::sm100::advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + if constexpr (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFA, kTmemStartColOfSFB); + } + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + ptx::tcgen05_after_thread_sync(); + + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } + } + } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b2adc6c7ad40cc84aef802a418c3702287774b20 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh @@ -0,0 +1,1380 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + // Template checks + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + // Thread indices + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + // Workspaces + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + // Token and buffer layouts + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered inputs + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + // SF and its buffer configs + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB"); + + // UTCCP 4x32 transpose index mapping within each 128-element group + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + // L1 inputs + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + + // L2 inputs + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr() + ); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr() + ); + + // Combine inputs + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr() + ); + + // Data types + // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + // MMA configs + // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + + // Swizzle configs + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + // Epilogue configs + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Shared memory + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // Shared memory sizes + // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage) + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages; + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + // Tensor memory size + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Assign shared memory for dispatch warps + const auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + // GEMM shared memory: C/D, A, B + // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + ); + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE); + }); + auto smem_cd_l2 = smem_cd[0]; + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SF shared memory: SFA and SFB per pipeline stage + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Epilogue amax reduction shared memory + auto smem_amax_reduction = reinterpret_cast(smem_sfb[kNumStages]); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); + + // A cluster sync is essential for 2CTA tensor memory allocation + comm::cluster_sync_with_relaxed_arrive(); + + // Initialization + if (warp_idx == 0) { + // Clean shared memory + if (cute::elect_one_sync()) + ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t)); + } else if (warp_idx == 1) { + // Init m-barriers for dispatch + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`, + // and `barrier.cluster.wait.aligned` is by default `.acquire` + comm::cluster_sync_with_relaxed_arrive(); + + // Task scheduler + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kNumExpertsPerWave, + kNumSMs, kNumRanks>(workspace); + + // MMA pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM Barrier indices + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Adjust registers + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts) + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + // Different warp roles + if (warp_idx < kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // Dispatch warps + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + // TODO: figure out better unrolling + // Now, `unroll` is better than `unroll 8` + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + // Allocate slots for each token-topk + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count experts' tokens + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Get SM offset (~6.5 us) + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source indices (~2 us with 512 tokens) + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx; + }); + + // Grid sync + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + // Write expert count + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Barrier before pulling + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* After the grid sync above, there is no more writes by other SMs (except 0) */ false, + /* After the NVLink barrier, there is a grid sync */ true + ); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Pull token data and SF from remote ranks into local L1 buffer + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + // Cache expert token counts in registers (same pattern as scheduler) + scheduler.fetch_expert_recv_count(); + + // Per-rank counts for current expert (re-loaded when expert changes) + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + // Advance expert until within the range + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + + // Update pool block offset for the new expert + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + + // Move start and end to the next expert + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + + // Finish all tokens + if (current_expert_idx >= kNumExpertsPerRank) + break; + + // Load per-rank counts when expert changes + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + // TODO: this is not coalesced + stored_rank_count[i] = j < kNumRanks ? + static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection via iterative min-peeling + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + // Compute active count and min across all ranks + // NOTES: reduce within each lane first, then warp-reduce once + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + // Hit in the current round + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + + // Move into the next round + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + // Read source token-topk index (written by remote dispatch via NVLink) + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA load token from remote rank into shared memory + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Load and store SF (overlaps with TMA token load) + constexpr uint32_t kNumSFUint32 = kHidden / 128; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(token_idx_in_expert); + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + // Store weights and token data + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + // Load weights + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + // Wait for TMA token load to complete + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + // Store token to local L1 buffer via TMA + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + // Write source metadata for combine write-back + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + // Wait for token TMA store to complete + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + // Clean workspace for the next usage, and also do cumulative stats + // NOTES: it is overlapped with combine reduction epilogue + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + // SM 0: clear expert send count + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + // Other SMs: clean blocks + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + // Read expert token count before clearing + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + // Compute expert pool block offset + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + // Wait read count ready + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Clean expert token count, and add cumulative results + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + // Clean per-rank token count + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + // Clean L1 and L2 arrival stuffs + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + // Wait for all ranks to finish cleaning + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* Before the NVLink barrier, there is a grid sync */ true, + /* At the end of kernel does not need to sync */ false + ); + } else if (warp_idx == kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for tokens with SFA + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u); + + // Compute pool block offset for this expert + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait the entire token arrival for linear 1 + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival + // NOTES: Originally we wait blocks on-demand to overlap L1 calculation + // with L2, but this optimization is negative when `num_experts_per_wave` + // guarantees L1's completion when L2 starts. So we remove it. + // In the future, if `num_experts_per_wave` is not large enough + // due to small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute token offset from pool block index + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx; + + // Add 2 CTA offsets for non-leader CTA + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + // TMA copy tokens and SFA, then arrive at full barrier + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for weights with SF + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + const auto tensor_map_sfb_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute weight offset + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; + + // TMA copy weights with SF + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM MMA issue warp (only the leader CTA will run) + if (is_leader_cta) { + // Make instruction descriptor with block scaling + // NOTES: always swap A/B + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K + >(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Dynamic update of UMMA N based on effective M + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + // Wait tensor memory empty barrier arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Launch MMAs + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA load completion + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // UTCCP copy SFA and SFB to TMEM + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + // Issue UMMA + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } + }); + + // To safely deconstruct barriers, we need another round of waits + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // GEMM epilogue warps + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + // TODO: support effective block M + // NOTES: + // - 2 warpgroups divide the whole BM into BM / 2 + // - 4 warps divide the whole BN into BN / 4 + // - BM / 2 is further divided into stored blocks, i.e. with `STORE_BLOCK_M` size + // - `STORE_BLOCK_M` in further divided into `ATOM_M` + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Wait UMMA arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + // Compute offsets + // NOTES: use shuffle here to let NVCC know warp divergence won't happen + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + if (block_phase == sched::BlockPhase::Linear1) { + // Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights + // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are: + // (values[0], values[2]), (values[1], values[3]), + // (values[4], values[6]), (values[5], values[7]) + float stored_cached_weight = 0; + + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + // Iterate all atoms in the store block + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + // Load weights from global into register cache per 32 tokens + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + // Load weights from register cache + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + // Load from TMEM + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Signal tensor memory consumed on the last atom + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Apply SwiGLU: silu(gate) * up + // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7) + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++ k) { + auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + // Clamp + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + // SwiGLU + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? __expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + // Amax reduction + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + // Wait shared memory release from previous TMA store + // And fence `smem_amax_reduction` + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Cast to FP8 E4M3 and store into shared memory + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + // Reduce amax + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + // Calculate SF + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + // Cast + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + // STSM + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) + // Only one warp per pair writes (both hold the same SF after cross-warp reduce) + // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduce the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store after all atoms in this store block + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + // Notify L2 + // TODO: less epilogue sync scope + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx + ); + } + __syncwarp(); + } else { + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + // L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + // TODO: check performance + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) { + // Load from TMEM using .16x256b shape to satisfy STSM layout requirements + // Start from lane index 0 and 16 + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Wait shared memory release from previous NVLink store + // NOTES: skip for the first store block since the prior full barrier already ensures completion + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Signal tensor memory consumed + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Store into shared memory + // NOTES: only use first 16 lanes for address + // NOTES: 2 warps share a BF16 swizzle atom + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr + ); + } + + // Wait shared memory ready + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Write into remote buffers + // One warp per row, now the layout is different from shared memory storing + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + + // Skip padding rows beyond the actual token count for this expert + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read from shared memory + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + // Write into remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + // Ensure the next epilogue safe to use shared memory + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // Deallocate tensor memory + // NOTES: must be called by the same logical warp ID on both CTAs + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Barrier with dispatch warps, so that they can do clean workspace + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Combine: reduce top-k results and write back + // NOTES: reuse shared memory from start up to the barriers + // 1 token, 1 topk latency: ~3 us + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + // 3 slots of chunk is needed: 2 load stages and 1 store + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + + // NOTES: either 1 or 2 chunks for simplicity + // NOTES: Restrict on both smem and register + constexpr uint32_t kNumChunks = + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + // Verify combined shared memory budget at runtime + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + // Per-warp buffer: 2 stage load buffers + 1 store buffer + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + // Per-warp barriers + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + // Iterate over all tokens + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + // Read top-k slot indices: each lane reads one slot, then broadcast via exchange + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + const int stored_topk_slot_idx = lane_idx < kNumTopk ? + static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + // Iterate all chunks + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + // Move mask and load + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + // Move + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + + // Load + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + // Load the first selection + bool do_reduce = move_mask_and_load(load_stage_idx); + + // Accumulate all top-k contributions for this chunk in float registers + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + // Prefetch next top-k into the buffer while current is being accumulated + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + + // Accumulate + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + // Cast + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + // Wait share memory release and write + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + // TMA store the token chunk + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 45a603add3f494aed51dce7aec53b5545bdc23f4..7ce008e5ea30ff8ad5ce65f0f3051d5f663c50df 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + if (kNumMulticast > 1) + cute::cluster_sync(); + // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll @@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 180a308b3279b38827741942917a31e103b15b52..e6744f59ac68a5b7a681ef4ff9ad985fdb5f5e51 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -6,27 +6,31 @@ #include #include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be load only at once for a block - const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warp_in_group_idx = warp_idx % 4; - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); // Shared memory configs // NOTES: weight may be unaligned @@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); // Align to 512 bytes for swizzle-64B extern __shared__ __align__(512) uint8_t smem_buffer[]; @@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); // Tensor memory allocation auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); // Initialize barriers DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); - const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); - const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { full_kv_barriers[i]->init(1); empty_kv_barriers[i]->init(kNumMathThreads); } - #pragma unroll - for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { - full_umma_barriers[i]->init(1); - empty_umma_barriers[i]->init(128); - } - - // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); - } else if (is_umma_warp) { + } + if (warp_idx == kSpecWarpStart + 1) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } // Allocate tensor memory cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); } __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 24; - constexpr uint32_t kNumMathRegisters = 240; + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; // Block scheduler - uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase @@ -177,13 +182,16 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; - if (is_tma_load_warp) { + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { cutlass::arch::warpgroup_reg_dealloc(); // Prefetch - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } num_total_kv_blocks += num_kv_blocks; @@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_descwait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } num_total_kv_blocks += num_kv_blocks; + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline + empty_q_barriers[q_stage_idx]->arrive(); + // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } else if (warp_idx >= kNumMathThreads / 32) { + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { cutlass::arch::warpgroup_reg_dealloc(); - } else if (warp_idx < kNumMathThreads / 32) { + } else if (warp_idx < kSpecWarpStart) { cutlass::arch::warpgroup_reg_alloc(); // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const auto& warp_offset = warp_idx * 32; - const auto& v_offset = lane_idx; + const auto tmem_start = warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Preload weights - constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); - float weights[BLOCK_Q][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[BLOCK_Q][kNumHeads]; while (block_q_idx < num_q_blocks) { CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); @@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Read weights #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); - } + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } // Compute over KV blocks @@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); // Reduce over the head dim and store - const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; - static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); - uint32_t shifted_accum[kNumLDTMElems]; - auto tmem_load = [&](auto... Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + // Load accumulator from TMEM + float accum[kNumHeads]; + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + } + // Accumulate weighted ReLU in parallel auto sum_0 = make_float2(0, 0); auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + const auto transform = [&](const uint32_t& j, const float2& sum) { auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); auto b = make_float2(weights[i][j], weights[i][j + 1]); return __ffma2_rn(a, b, sum); }; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); - } - - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); } auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + auto result = static_cast(scale_kv * (sum.x + sum.y)); // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); if constexpr (kIsCompressedLogits) { - if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; } else { - logits[q_idx * stride_logits + kv_offset + v_offset] = result; + logits[q_offset + kv_offset] = result; } + __syncwarp(); } } num_total_kv_blocks += num_kv_blocks; @@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } - // Free tensor memory - __syncthreads(); - if (is_tma_load_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 7058c40f4f195de94184d3e7ebc6f9aa2eb3670f..9a5bddbf37ef0f0ce679ef7f553ee6084b92a44c 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -6,56 +6,65 @@ #include #include +#include +#include +#include #include -#include -#include - -#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); // Shared memory configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; - static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); - static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -63,43 +72,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q and KV data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); }); constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); // Barriers and TMEM pointer on shared memory const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); - constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups; DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); - const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); - const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); - const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); // Initialize barriers - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { @@ -108,7 +114,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } cutlass::arch::fence_barrier_init(); } - if (is_umma_warp) { + if (warp_idx == kSpecWarpStart + 1) { if (cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { @@ -123,79 +129,92 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); // Scheduler constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + using Scheduler = sched::PagedMQALogitsScheduler; DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; - uint32_t q_iter_idx = 0, kv_iter_idx = 0; // UMMA settings // Construct instruction with layout D constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - constexpr uint32_t UMMA_N = kNextN * kNumHeads; + constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); - if (is_tma_load_warp) { - // TMA warp-group for loading data + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading data cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) { if (cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx, num_kv; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; bool fetched_next_task; // Prefetch the first Q - if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) - issue_tma_q(0, next_q_idx), q_iter_idx = 1; + if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1; - int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_ptr = 32; uint32_t kv_block_idx_storage; while (fetched_next_task) { - // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); - q_idx = next_q_idx; + // Prefetch next Q when (q, atom) changes + const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size); + bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance); + + if (q_atom_idx != next_q_atom_idx) + kv_block_idx_ptr = 32; + + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; // Read KV block index - // TODO: deal with `-1`? - if (kv_idx == 0 or kv_block_idx_ptr == 32) { + // TODO(xuzhean): consider -1 + if (kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; } + __syncwarp(); DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); - issue_tma_q(q_stage_idx, q_idx + 1); + issue_tma_q(q_stage_idx, q_atom_idx + next_advance); } - int kv_block_idx[kNumBlocksPerSplit]; + uint32_t kv_block_idx[kNumBlocksPerSplit]; #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); kv_block_idx_ptr += kNumBlocksPerSplit; @@ -205,45 +224,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, if (cute::elect_one_sync()) { #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, - 0, 0, 1, kv_block_idx[i]); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, - 0, kv_block_idx[i]); + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) { + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); } full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } // Fetch next task - fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv); } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_desc(); auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 1; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - if (q_idx != next_q_idx) { + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + if (q_atom_idx != next_q_atom_idx) { + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); full_q_barriers[q_stage_idx]->wait(q_phase); } - q_idx = next_q_idx; + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; + // Wait KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); @@ -251,12 +278,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { empty_umma_barriers[i]->wait(umma_phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -264,29 +291,46 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } umma_phase ^= 1; } - } else if (is_math_warp) { - // Math warp-groups for WGMMA + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const uint32_t thread_idx = threadIdx.x; + const auto math_warpgroup_idx = warpgroup_idx; + const auto tmem_start = math_warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Weights - constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads)); - float weights[kNextN][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[kNextNAtom][kNumHeads]; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 0; + bool is_paired_atom = false; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - // Current Q changes - if (q_idx != next_q_idx) { - // Release Last Q empty + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + // Q or atom changes + if (q_atom_idx != next_q_atom_idx) { + // Release last Q empty if (q_iter_idx > 0) empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); @@ -296,30 +340,34 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Read weights #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2); } } - // Get current Q and KV index - q_idx = next_q_idx; + // Get current task indices + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV; - // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival - full_umma_barriers[warpgroup_idx]->wait(umma_phase); - tcgen05_after_thread_sync(); + full_umma_barriers[math_warpgroup_idx]->wait(umma_phase); + ptx::tcgen05_after_thread_sync(); umma_phase ^= 1; // Release KV empty @@ -327,72 +375,65 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Reduce over the head dim and store DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); - auto tmem_load = [&](auto... Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - - #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); - - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(weights[i][j], weights[i][j + 1]); - return __ffma2_rn(a, b, sum); - }; + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + float accum[kNumHeads]; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; + __syncwarp(); } - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); - } - - auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[math_warpgroup_idx]->arrive(); + }; - // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + thread_idx] = result; + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); } } - } else { - cutlass::arch::warpgroup_reg_dealloc(); - } - // Free tensor memory - __syncthreads(); - if (is_umma_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh index 4e4ff21d0746cff7bc7ecaf23a49278a2f5810cc..aaf7fd9aea773fc66a696f5c9382b8b0e53e263d 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -4,20 +4,22 @@ #include -#include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { // Calculate the index of the bank group to be written in the atom - const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); // Reshape the atom in another view and swizzle // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` @@ -37,7 +39,7 @@ template -__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Prefetch TMA descriptors at the very beginning if (warp_idx == 0 and cute::elect_one_sync()) { @@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; // Fill the tensor memory pointer @@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t m_offset = shape_m * k_split_idx; const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Dispatch warps into different roles if (warp_idx < kNumMMAThreads / 32) { // TMA load warp @@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; // Checks for MMA instructions @@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& stage_idx = s % kNumStages; const auto& cast_stage_idx = s % kNumCastStages; full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); @@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); } @@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); @@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); if constexpr (BLOCK_M == 64) __syncwarp(); } @@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, #pragma unroll for (uint32_t i = 0; i < kNumLoads; i += 2) { auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); - sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], - uint32_values[0][i + 1], uint32_values[1][i + 1], - smem_ptr); + ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); } // Wait tensor memory empty @@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, cutlass::arch::fence_view_async_tmem_store(); // Arrive for issuing MMAs - tcgen05_before_thread_sync(); + ptx::tcgen05_before_thread_sync(); full_cast_barriers[cast_stage_idx]->arrive(); } // Intra-warp reduction and write back #pragma unroll for (uint32_t u = 0; u < 2; ++ u) { - const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); - const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; if (lane_idx % 4 == 0 and m_idx < shape_m) sqr_sum[m_offset + m_idx] = reduced_sum; } diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 7a77e4e8fbbbffa56e8c8632ade7ae7938b30ee9..84a149eb9b6b35a907f03b4c04434ee9f8e558ee 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -11,14 +11,19 @@ #include #include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename mma::sm90::BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); @@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); @@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Configs const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout, // D/A/B shared memory auto smem_d = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { @@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumTMARegisters = 48; constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); @@ -159,31 +167,30 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; - const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); - const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Issue TMAs constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } } @@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); - auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm90::make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = mma::sm90::make_gmma_desc(smem_b[0], 0, 0); const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); @@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout, }; // TODO: remove some useless computation for unaligned Ms - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { - const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); - const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -240,26 +247,26 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; - a_desc.reg32_[0] = advance_gmma_desc_lo( + const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); - b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival empty_barrier_arrive(stage_idx); @@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, } // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( + ptx::SM90_U32x2_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), smem_ptr @@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* grouped_layout, auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); - st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); } } } @@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); // Use TMA store to write back to global memory - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh index 191a4fe2c4ccf66b0743affedcbfd17950e2618f..7c344296519e7dd0852a8940d3e9d714b12a5646 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -4,26 +4,32 @@ #include #include +#include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, float *d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename mma::sm90::BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); @@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Configs const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); @@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Align to 1024 bytes for swizzle-128B // Fill shared memory pointers extern __shared__ __align__(1024) uint8_t smem_buffer[]; - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, constexpr uint32_t kNumMathRegisters = 232; // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); @@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, #pragma unroll for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; - const uint32_t& k_idx = sk_idx % SHAPE_K; - const uint32_t& s_idx = sk_idx / SHAPE_K; + const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t k_idx = sk_idx % SHAPE_K; + const uint32_t s_idx = sk_idx / SHAPE_K; constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } @@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrivals - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); WGMMA::wgmma(desc_a, desc_b, accum, 1); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave empty_barriers[stage_idx]->arrive(); } - const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; - const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { if (col + i * 8 >= SHAPE_N) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index cdd28fcb59d3b038c84c007ef1da1477d7ca263a..195d431f9067abcd94ce3c27e1ea7bf60ada7224 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -6,18 +6,26 @@ #include #include +#include #include #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, int* grouped_layout, cute::TmaDescriptor* tensor_map_buffer, @@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); @@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0); static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); // Configs @@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Tensor maps on shared and global memory - auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); - }); - auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); - }); - auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); - auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + auto smem_tensor_map_a = reinterpret_cast(smem_buffer); + auto smem_tensor_map_b = smem_tensor_map_a + 1; + auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2; + auto gmem_tensor_map_b = gmem_tensor_map_a + 1; // Data on shared memory auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); - auto smem_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); }); - auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); }); // Barriers on shared memory constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); - auto full_barriers = PatternVisitor([&](const uint32_t& i) { + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); }); - auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); }); if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { // Load tensormap A/B to shared memory if constexpr (kGemmType == GemmType::KGroupedContiguous) { - *smem_tensor_map_a[0] = tensor_map_a_base; - *smem_tensor_map_a[1] = tensor_map_a_base; - *smem_tensor_map_b[0] = tensor_map_b_base; - *smem_tensor_map_b[1] = tensor_map_b_base; + *smem_tensor_map_a = tensor_map_a_base; + *smem_tensor_map_b = tensor_map_b_base; } // Initialize barriers @@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // TMA and MMA pipeline - const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase }; uint32_t iter_idx = 0; @@ -165,9 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // NOTES: only one thread (or warp) will be used if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { - const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; - const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; - uint32_t last_group_idx = kNumGroups, sum_k = 0; + uint32_t last_group_idx = kNumGroups; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -177,35 +180,27 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); - const uint32_t& m_idx = m_block_idx * BLOCK_M; - const uint32_t& n_idx = n_block_idx * BLOCK_N; - - if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { - const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; - const uint32_t& next_stage_idx = stage_idx ^ 1; + + const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t m_idx = m_block_idx * BLOCK_M; + const uint32_t n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) { last_group_idx = scheduler.current_group_idx; - // Prepare next tensor map - sum_k += scheduler.current_shape_k; - if (scheduler.next_group_idx < kNumGroups) { - tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); - tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); - tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); - tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); - *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); - *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); - tensor_map_release_cta(); - } - - // Get current tensor map - if (scheduler.current_num_valid_groups > 0) { - tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); - tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); - current_tensor_map_a = gmem_tensor_map_a[stage_idx]; - current_tensor_map_b = gmem_tensor_map_b[stage_idx]; - } + // Directly update current tensor map + const uint64_t current_k_offset = scheduler.current_k_cumsum; + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m); + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k); + *(gmem_tensor_map_a) = *(smem_tensor_map_a); + *(gmem_tensor_map_b) = *(smem_tensor_map_b); + ptx::tensor_map_release_gpu(); + + // Immediately acquire current tensor map + ptx::tensor_map_acquire_gpu(gmem_tensor_map_a); + ptx::tensor_map_acquire_gpu(gmem_tensor_map_b); } #pragma unroll kNumPipelineUnrolls @@ -216,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Issue TMA auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& k_idx = k_block_idx * BLOCK_K; - const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; - tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); - tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); - tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); - tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + const uint32_t k_idx = k_block_idx * BLOCK_K; + const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base); + const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base); + tma::copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma::copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma::copy(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma::copy(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); } } @@ -248,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Accumulation for WGMMA or CUDA promotion DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); - const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); - const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); - const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K); float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; float2 scales_b[WGMMA::kNumAccum / 4]; @@ -272,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); - auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1); // Read B scales #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) - scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + scales_b[i] = ptx::ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival empty_barrier_arrive(stage_idx); @@ -318,12 +315,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cutlass::arch::NamedBarrier::sync(128, math_wg_idx); // Store to D shared memory - const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); - const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + const auto smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); #pragma unroll for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); - st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, math_wg_idx); diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 9247304cdd17d8e2c3a5cdb31c78c191ae6b76ec..aa412484debb328df8f4f4d0d7cdfc1c61ec7b69 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -10,17 +10,21 @@ #include #include -#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { +CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { if (num_former_iters == kNumFormerIters) { func(cute::Int{}); return; @@ -35,12 +39,12 @@ template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + DG_STATIC_ASSERT( + math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or + (math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); @@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Shared memory static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); - const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); - const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K); + const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K); + const uint32_t smem_sfb_size = math::align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); // NOTES: Make sure we have enough shared memory for WGMMA padding static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); // Configs - const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Data on shared memory auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); }); auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); // Fill barriers auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); - auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); - auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); @@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr uint32_t kNumTMARegisters = 40; constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; const uint32_t k_idx = k_block_idx * BLOCK_K; - tma_copy(&tensor_map_a, &full_barrier, + tma::copy(&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), num_tma_multicast_a, batch_idx); - tma_copy(&tensor_map_sfa, &full_barrier, - smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + tma::copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), num_tma_multicast_a); // Issue TMA B - tma_copy(&tensor_map_b, &full_barrier, + tma::copy(&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), num_tma_multicast_b, batch_idx); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); @@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); - auto b_desc = make_smem_desc(smem_b[0], 1); + auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1); const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); @@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { - auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; #pragma unroll for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) - st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]); } cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); @@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Skip useless computations if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { // The compiler must know the dynamic variable `num_former_iters`'s real value - constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; - constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8; constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; // Dispatch `num_former_iters` and launch MMAs dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { #pragma unroll 8 for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { - const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); - const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); // Read B scales - float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1; // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; - auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; WGMMA::wgmma(a_desc, b_desc, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) @@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + const bool predicate = kMustUseUniformedScaleB or i < num_former_iters; shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; @@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( + ptx::SM90_U32x2_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), smem_ptr diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh index d58c716242a09922157aa13e16cb8afac477904c..225af4416810b2680317d3713c372807e548f464 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -7,36 +7,31 @@ #include #include +#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - -// ReSharper disable once CppNotAllPathsReturnValue -template -static constexpr int to_swizzle_cute_type() { - DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); - if constexpr (kHeadDim == 32) - return static_cast(cute::SM90::GMMA::LayoutType::B32); - if constexpr (kHeadDim == 64) - return static_cast(cute::SM90::GMMA::LayoutType::B64); - if constexpr (kHeadDim == 128) - return static_cast(cute::SM90::GMMA::LayoutType::B128); -} - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumSMs, + uint32_t kNumTMAThreads, uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TODO: consider TMA multicast // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be load only at once for a block - const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // Prefetch TMA descriptors @@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); // Initialize barriers - const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; if (is_tma_load_warp and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { @@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t kNumMathRegisters = 112; // Block scheduler - uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + const auto sm_idx = blockIdx.x; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase }; }; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); @@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Prefetch const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -212,7 +212,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const auto& thread_idx = threadIdx.x % kNumMathThreads; const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto& lane_idx = ptx::get_lane_idx(); float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; const auto& warp_offset = warp_idx * 16; @@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, for (uint32_t i = 0; i < BLOCK_Q; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } // Compute over KV blocks @@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); // Issue WGMMA DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, - to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, - to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); if constexpr (kIsCompressedLogits) { if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast(v_0); if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast(v_1); } else { - logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; - logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + logits[q_offset + kv_offset + v_0_offset] = static_cast(v_0); + logits[q_offset + kv_offset + v_1_offset] = static_cast(v_1); } } } diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index 482a85a80fce29aa949b464070b0b20fb55ae030..cc2592bb402af88f3d7c7b841f26e1961093c8a3 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -6,133 +6,46 @@ #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(32, 1) -void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, - const uint32_t* context_lens, uint32_t* schedule_metadata) { - DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); - const uint32_t lane_idx = get_lane_idx(); - - uint32_t num_segs[kAlignedBatchSize / 32]; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - const uint32_t q_idx = k * 32 + lane_idx; - const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); - const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); - num_segs[k] = ceil_div(context_len, SPLIT_KV); - } - - __shared__ uint32_t prefix_sum[kAlignedBatchSize]; - uint32_t sum = 0; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - uint32_t x = num_segs[k]; - #pragma unroll - for (uint32_t offset = 1; offset < 32; offset <<= 1) { - const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); - x += (lane_idx >= offset ? y : 0); - } - x += sum; - prefix_sum[k * 32 + lane_idx] = x; - sum = __shfl_sync(0xffffffff, x, 31); - } - - const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; - for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { - uint32_t seg_starts = sm_idx * q + min(sm_idx, r); - uint32_t q_idx = 0; - while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) - ++ q_idx; - const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); - __syncwarp(); - - schedule_metadata[sm_idx * 2] = q_idx; - schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; - } -} - -template -struct PagedMQALogitsScheduler { - uint32_t batch_size; - const uint32_t* context_lens; - - uint32_t current_q_idx, current_kv_idx; - uint32_t end_q_idx, end_kv_idx; - uint32_t current_num_kv; - - __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { - const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); - return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; - } - - __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, - const uint32_t* context_lens, const uint32_t* schedule_meta) { - this->batch_size = batch_size; - this->context_lens = context_lens; - - const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); - const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); - current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; - end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; - - current_num_kv = get_num_kv(current_q_idx); - } - - __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { - q_idx = current_q_idx; - kv_idx = current_kv_idx; - num_kv = current_num_kv; - - if (q_idx == end_q_idx and kv_idx == end_kv_idx) - return false; - - current_kv_idx += kNumBlocksPerSplit; - if (current_kv_idx >= current_num_kv) { - ++ current_q_idx; - current_kv_idx = 0; - current_num_kv = get_num_kv(current_q_idx); - } - - return true; - } - - __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { - return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; - } -}; - -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumTMAThreads, uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits"); + // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; @@ -150,15 +63,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + - constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + - constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -166,31 +79,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q data and barriers on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); }); auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); // Separate math warpgroups and tma load warps into KV groups // Each math warpgroup corresponds to a tma load warp - const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); // Per group KV data and barriers on shared memory - const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); // Initialize barriers if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -218,15 +131,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, constexpr uint32_t kNumTMARegisters = 64; constexpr uint32_t kNumMathRegisters = 104; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + auto scheduler = sched::PagedMQALogitsScheduler( + blockIdx.x, batch_size, context_lens, schedule_meta, indices); DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; uint32_t q_iter_idx = 0, kv_iter_idx = 0; @@ -237,10 +154,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_group_idx >= kNumMathWarpGroups) return; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { if (kv_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -259,7 +176,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, while (fetched_next_task) { // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1)); q_idx = next_q_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; @@ -276,9 +193,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_idx == 0 or kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? - __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + block_table[q_idx * static_cast(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0); } - const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); // Wait KV consumer release CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); @@ -286,10 +203,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Issue TMA KV if (cute::elect_one_sync()) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -301,9 +218,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, cutlass::arch::warpgroup_reg_alloc(); float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; - const auto& sub_warp_offset = (warp_idx % 4) * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - const auto& v_1_offset = lane_idx / 4 + 8; + const auto sub_warp_offset = (warp_idx % 4) * 16; + const auto v_0_offset = lane_idx / 4 + 0; + const auto v_1_offset = lane_idx / 4 + 8; // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none uint32_t q_idx = batch_size, kv_idx; @@ -326,7 +243,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, for (uint32_t i = 0; i < kNextN; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } } @@ -335,7 +252,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + auto kv_offset = q_idx * kNextN * static_cast(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` // Wait TMA KV arrival @@ -347,25 +264,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); // Wait WGMMA - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -378,7 +299,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -396,15 +317,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Inter-thread reduction #pragma unroll for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = static_cast(1u << j); + const auto offset = static_cast(1u << j); v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); } // Store into the global memory // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + v_0_offset] = v_0; - logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + logits[kv_offset + i * static_cast(logits_stride) + v_0_offset] = static_cast(v_0); + logits[kv_offset + i * static_cast(logits_stride) + v_1_offset] = static_cast(v_1); } } } diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh index e3bf98478923a2bf560e69e6ecc802d218fb82c1..93b14100109c282fe4705af37bcf84547e20b3f5 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -5,20 +5,23 @@ #include #include -#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; - const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; @@ -35,7 +38,7 @@ template -__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t kNumTMARegisters = 40; constexpr uint32_t kNumMathRegisters = 256; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // TMA load warp if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cutlass::arch::warpgroup_reg_dealloc(); for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); // Compute offsets @@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); } } else if (warp_idx < kNumMathThreads / 32) { @@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t WGMMA_N = BLOCK_N; constexpr uint32_t WGMMA_K = 8; - using WGMMA = typename TF32MMASelector::type; + using WGMMA = typename mma::sm90::TF32MMASelector::type; float accum[WGMMA::kNumAccum] = {0}; constexpr uint32_t kNumBankGroupBytes = 16; @@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); if (s > 0) empty_barriers[(s - 1) % kNumStages]->arrive(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; @@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { #pragma unroll for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { - auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + auto b_desc = mma::sm90::make_smem_desc( + smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); } - const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); - const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1); const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); if (lane_idx % 4 == 0) { @@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, if (m_idx + 8 < shape_m) sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); // Write accum to shared memory @@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // 0/1 write to the same row, 2/3 write to another row auto values = reinterpret_cast(accum + i * 2); - st_shared(smem_ptr, values[0], values[1]); - st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1]); + ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, 1); diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh index cc9e5e6b0c7ce95acf0b7149221dc4d4f0f83a21..2f66b980c5f2c6d8c51e85b4feb47d9efefe1b64 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -3,21 +3,24 @@ #include #include -#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(kNumWarps * 32, 1) +template +CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1) void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, - const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { - const uint32_t& num_sms = gridDim.x; - const uint32_t& sm_idx = blockIdx.x; - const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - constexpr float neg_inf = -cute::numeric_limits::infinity(); + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) { + const uint32_t num_sms = gridDim.x; + const uint32_t sm_idx = blockIdx.x; + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t); + const logits_dtype_t neg_inf = -cute::numeric_limits::infinity(); // Allocate filled `-inf` shared memory - extern __shared__ __align__(1024) float smem_buffer[]; + extern __shared__ __align__(1024) logits_dtype_t smem_buffer[]; #pragma unroll for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) smem_buffer[i] = neg_inf; @@ -25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const __syncthreads(); // Assign sequence to each warp - const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, - const uint32_t& start, const uint32_t& total) -> cute::tuple { - const auto& per = total / num, rem = total % num; - return {start + idx * per + min(idx, rem), per + (idx < rem)}; + const auto assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto per = total / num, rem = total % num; + return {start + idx * per + cute::min(idx, rem), per + (idx < rem)}; }; CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (cute::elect_one_sync()) { for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { - const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + const auto right = cute::min(left + BLOCK_KV, static_cast(stride_logits)); if (right <= ks or ke <= left) { - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t)); } else { if (left < aligned_ks) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t)); if (aligned_ke < right) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t)); } } } } + __syncwarp(); for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t j = aligned_ks; j < ks; ++ j) logits[i * stride_logits + j] = neg_inf; for (uint32_t j = ke; j < aligned_ke; ++ j) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh index bea7000276c3e382c1acfeff545d6181351849b6..a977c5547217363e545498f7ca25ee6108056afb 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh @@ -1,13 +1,16 @@ #pragma once +#include #include +#include +#include namespace deep_gemm { template -__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { - typedef typename Vectorized::vec_t in_vec_t; +CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename utils::Vectorized::vec_t in_vec_t; constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; @@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { extern __shared__ float smem_buffer[]; constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the block sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Load for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { - auto in_vec = __ldg(local_sf + i); + auto in_vec = local_sf[i]; const auto& in_values = reinterpret_cast(&in_vec); const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; @@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; - out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); } } // NOTES: the two kernels below always pack the K dimension template -__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { +CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { extern __shared__ uint32_t smem_buffer[]; // Shapes and strides - constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u); constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the group sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Load FP32 SFs DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); @@ -66,13 +75,13 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con const auto num_uint4 = num_values / 4; #pragma unroll for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { - const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); - st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + const auto& [x, y, z, w] = reinterpret_cast(local_sf)[i]; + ptx::st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); } // Fill unaligned values as well if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) - st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]); __syncthreads(); // Pack into UE8M0 and store @@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { const auto sf_k_idx = sf_k_pack_idx * 4 + j; - values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + values[j] = sf_k_idx < SF_K ? ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; } // Pack and store @@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con template -__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, - const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { +CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k, + const uint32_t gran_k) { // Always packing the K dimension // NOTES: should also assert `mn % 4 == 0` at launch DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); @@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, // Each warp is responsible for a packed row const auto warp_idx = threadIdx.x / 32; - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; if (warp_idx >= in_block_packed_sf_k) return; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Make an offset on the input uint32_t input_offset = 0; if constexpr (kNumGroups > 1) { @@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, #pragma unroll for (uint32_t i = 0; i < 4; ++ i) { const auto group_idx = lane_idx * 4 + i; - group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0; } __syncwarp(); // Make the offset sf_k = 0; - auto sum_packed_sf_k = 0; + uint32_t sum_packed_sf_k = 0; #pragma unroll for (uint32_t i = 0; i < kNumGroups; ++ i) { - const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4); sf_k += sf_k_in_group; - sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u); if (packed_sf_k_idx < sum_packed_sf_k) break; if (const auto remainder = sf_k_in_group % 4; remainder > 0) @@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, } } - for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { // Load uint4 values[4]; #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { values[j] = make_uint4(0, 0, 0, 0); if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) - values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + values[j] = reinterpret_cast(sf + sf_k_idx * mn)[mn_idx]; } // Pack and store diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..13520c60e29b37b6ab8ebed4bc0fd8ac26bbb63e --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh @@ -0,0 +1,260 @@ +#pragma once + +#include + +#include +#include + +namespace deep_gemm::layout { + +static constexpr int kNumCandidateBlockMs = 7; +static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; +static constexpr int kMaxCandidateBlockM = 192; +static constexpr int kMinCandidateBlockM = 8; +static constexpr int kLCMCandidateBlockM = 384; + +// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M +template +CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk, + T num_experts_per_rank) { + const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank; + const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank); + return math::constexpr_align( + num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast(kMaxCandidateBlockM) - 1), + static_cast(kLCMCandidateBlockM)); +} + +// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M +template +CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) { + return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast(128)); +} + +// Per-token source metadata for combine write-back +struct TokenSrcMetadata { + uint32_t rank_idx; + uint32_t token_idx; + uint32_t topk_idx; +}; + +struct Workspace { + void* base; + uint32_t num_ranks, num_experts; + uint32_t num_experts_per_rank; + uint32_t num_max_tokens_per_rank; + uint32_t num_max_recv_tokens_per_expert; + + // Pool capacity: all local experts share a contiguous token pool + uint32_t num_max_pool_tokens; + uint32_t num_max_pool_blocks; + + // For both grid barrier and NVLink barrier + static constexpr uint64_t kNumBarrierSignalBytes = 32; + + CUTLASS_HOST_DEVICE + Workspace(void* base, + const uint32_t& num_ranks, + const uint32_t& num_experts, + const uint32_t& num_max_tokens_per_rank, + const uint32_t& num_topk): + base(base), + num_ranks(num_ranks), num_experts(num_experts), + num_max_tokens_per_rank(num_max_tokens_per_rank) { + num_experts_per_rank = num_experts / num_ranks; + num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank; + num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); + num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM; + } + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + uint64_t num_bytes = 0; + + // Barrier + num_bytes += kNumBarrierSignalBytes; + + // Expert send/recv count + num_bytes += num_experts * sizeof(uint64_t) * 2; + + // Expert recv count sum + num_bytes += num_experts_per_rank * sizeof(uint64_t); + + // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask) + num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t); + + // L2 block arrival mask + num_bytes += num_max_pool_blocks * sizeof(uint64_t); + + // Dispatch pulling source token-topk + num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int); + + // Combine push source indices + num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata); + + // Align to TMA descriptor requirements + num_bytes = math::align(num_bytes, 16); + return num_bytes; + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + // Grid sync counters: `kNumBarrierSignalBytes` layout + // [ 0..15]: 4 x `uint32_t` grid sync counters + // [16..20]: `uint32_t` NVLink barrier counter + // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1) + static constexpr uint32_t kNumMaxGridSyncCounters = 4; + + template + CUTLASS_DEVICE + uint32_t* get_grid_sync_count_ptr() const { + DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds"); + return static_cast(base) + kIndex; + } + + CUTLASS_DEVICE + uint32_t* get_nvl_barrier_counter_ptr() const { + return static_cast(base) + kNumMaxGridSyncCounters; + } + + CUTLASS_DEVICE + int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const { + // NOTES: the signal is signed, as we may minus + return math::advance_ptr(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int)); + } + + CUTLASS_DEVICE + uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const { + return math::advance_ptr(base, kNumBarrierSignalBytes) + expert_idx; + } + + CUTLASS_DEVICE + uint64_t* get_expert_recv_count_ptr( + const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const { + return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx; + } + + CUTLASS_DEVICE + uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const { + return get_expert_send_count_ptr(num_experts * 2) + expert_idx; + } + + CUTLASS_DEVICE + uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const { + const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank); + return reinterpret_cast(base) + pool_block_idx; + } + + CUTLASS_DEVICE + uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const { + // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned + const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u)); + return reinterpret_cast(base) + pool_block_idx; + } + + // For dispatch pulling + CUTLASS_DEVICE + uint32_t* get_src_token_topk_idx_ptr( + const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const { + const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks); + return reinterpret_cast(base) + + expert_idx * (num_ranks * num_max_recv_tokens_per_expert) + + rank_idx * num_max_recv_tokens_per_expert + token_idx; + } + + // For combine usages + CUTLASS_DEVICE + TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const { + const auto base = reinterpret_cast(get_src_token_topk_idx_ptr(num_experts_per_rank)); + return base + pool_token_idx; + } +}; + +struct Data { + uint32_t num_bytes; + bool require_tma_alignment; + void* base; + + CUTLASS_HOST_DEVICE + constexpr explicit Data( + const uint32_t& num_bytes, + const bool& require_tma_alignment = true, + void* base = nullptr) : + num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) { + DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment); + } + + template + CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const { + return static_cast(num_bytes); + } + + template + CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { + return static_cast(base); + } + + CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) { + base = ptr; + } +}; + +struct Buffer { + Data data_layout; + uint32_t num_ranks; + uint32_t num_max_tokens_per_rank; + + void* base; + + CUTLASS_HOST_DEVICE + Buffer(const Data& data_layout, + const uint32_t& num_ranks, + const uint32_t& max_num_tokens_per_rank, + void* base = nullptr) : + data_layout(data_layout), + num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank), + base(base) {} + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes_per_rank() const { + return num_max_tokens_per_rank * data_layout.get_num_bytes(); + } + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + return get_num_bytes_per_rank() * num_ranks; + } + + template + CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { + return static_cast(base); + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + CUTLASS_HOST_DEVICE + Buffer get_rank_buffer(const uint32_t& rank_idx) const { + return { + data_layout, + 1, num_max_tokens_per_rank, + math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx) + }; + } + + CUTLASS_HOST_DEVICE + Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const { + DG_DEVICE_ASSERT(num_ranks == 1 or global); + return Data( + data_layout.num_bytes, + data_layout.require_tma_alignment, + math::advance_ptr(base, data_layout.get_num_bytes() * token_idx) + ); + } +}; + +} // namespace deep_gemm::layout diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7f11aabc912b82d616779e9999ecfd00d19c9b93 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace deep_gemm::layout { + +constexpr static uint32_t kNumMaxRanks = 72; + +template +struct SymBuffer { + int64_t base; + int64_t offsets[kNumMaxRanks]; + uint32_t rank_idx; + + DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks"); + + SymBuffer() = default; + + template + explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) { + const auto size = static_cast(c.size()); + base = c[rank_idx]; + for (uint32_t i = 0; i < kNumMaxRanks; ++ i) + offsets[i] = i < size ? (c[i] - base) : 0; + } + +#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__) + template + CUTLASS_DEVICE ptr_t get_base_ptr() const { + return reinterpret_cast(base); + } + + template + CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const { + int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast(ptr); + return *reinterpret_cast(&mapped_ptr); + } +#endif +}; + +} // namespace deep_gemm::layout diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm100.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm100.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0c554f4cd65c253582294152c8e72e79ccd92a42 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm100.cuh @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace deep_gemm::mma::sm100 { + +/// Shared memory descriptor +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +CUTLASS_DEVICE +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +CUTLASS_DEVICE +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +/// UMMA descriptors +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : tma::get_inner_block_atom_size(); +} + +template +CUTLASS_DEVICE +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto layout_type = to_umma_layout_type(); + const auto num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id( + cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +} // namespace deep_gemm::mma::sm100 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm90.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm90.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2c061940deef5a25849173c6d052eed4f0d24130 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm90.cuh @@ -0,0 +1,293 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace deep_gemm::mma::sm90 { + +/// MMA +template +struct FP8MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + template + CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +/// Shared memory descriptor +template +CUTLASS_DEVICE cute::GmmaDescriptor +make_smem_desc(PointerType smem_ptr, const int& layout_type, + const uint32_t& leading_byte_offset = 0, + const uint32_t& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +CUTLASS_DEVICE +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +} // namespace deep_gemm::mma::sm90 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c3e03bec73d858c77b9f393ee091a52b5bdd01ac --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Compatibility: 256 bits LD/ST instructions +#if defined(CUDART_VERSION) and CUDART_VERSION >= 13000 +using longlong4_t = longlong4_32a; +#define make_longlong4_t make_longlong4_32a +#else +struct alignas(32) longlong4_t { long long x, y, z, w; }; +CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t( + const long long& x, const long long& y, const long long& z, const long long& w) { + return {x, y, z, w}; +} +#endif + +/// LD/ST matrix +// TODO: remove `struct` +struct SM90_U32x2_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +template +struct SM90_U32x2_STSM_N { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +template +struct SM100_U8x4_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src = *reinterpret_cast(&src_0); + asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src)); + } +}; + +template +struct SM100_U8x8_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +/// Shared memory +CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) { + // `size` must be 64-bit before PTX ISA 9.0 + asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" :: + "l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast(num_bytes))); +} + +/// Global memory +CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +/// Atomics +CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) { + uint32_t ret; + asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); + return ret; +} + +CUTLASS_DEVICE void red_add(const int* ptr, const int& value) { + asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) { + asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE int ld_acq_sys(const int* ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +/// Predicated loads +CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) { + longlong4_t ret = make_longlong4_t(0, 0, 0, 0); + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " setp.ge.s32 p, %5, 0;\n\t" + " @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t" + "}" + : "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w) + : "l"(ptr), "r"(pred) + : "memory"); + return ret; +} + +/// Prefetch +CUTLASS_DEVICE void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh new file mode 100644 index 0000000000000000000000000000000000000000..528b3dd10318a5d7493ec976c560774013fd4af8 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh @@ -0,0 +1,168 @@ +#pragma once + +namespace deep_gemm::ptx { + +/// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +/// Tensor memory operations +CUTLASS_DEVICE void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +CUTLASS_DEVICE void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tma.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tma.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1530a3edc57a81e7067a4929f9088929848a8960 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tma.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Tensor-map instructions +CUTLASS_DEVICE void tensor_map_release_gpu() { + asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory"); +} + +CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +/// TMA instructions +CUTLASS_DEVICE void mbarrier_arrive( + cutlass::arch::ClusterTransactionBarrier* ptr) { + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" :: + "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_arrive_and_set_tx( + cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) { + asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: + "r"(num_bytes), "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_wait_and_flip_phase( + cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) { + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" :: + "r"(static_cast(__cvta_generic_to_shared(ptr))), + "r"(phase), "r"(0x989680)); + phase ^= 1; +} + +CUTLASS_DEVICE void tma_load_1d( + const void* dst_ptr, const void* src_ptr, + cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr, + const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) { + // NOTES: normally, the loaded part will be evicted soon + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" :: + "r"(static_cast(__cvta_generic_to_shared(dst_ptr))), + "l"(src_ptr), + "r"(num_bytes), + "r"(static_cast(__cvta_generic_to_shared(mbarrier_ptr))), + "l"(hint) + : "memory"); +} + +CUTLASS_DEVICE void tma_store_1d( + const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) { + // NOTES: normally, the stored part will be used soon + asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" :: + "l"(dst_ptr), + "r"(static_cast(__cvta_generic_to_shared(src_ptr))), + "r"(num_bytes), + "l"(hint) + : "memory"); +} + +template +__forceinline__ __device__ void tma_store_wait() { + // NOTES: this function does not have `.read` + asm volatile("cp.async.bulk.wait_group %0;" ::"n"(kNumRemainingWaits) : "memory"); +} + +CUTLASS_DEVICE +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier, + void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) { + const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/utils.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5c27166b79ce710bd9eb99354e19fe1e6342dbaa --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/utils.cuh @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +CUTLASS_DEVICE uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +template +CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) { + DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, ""); + const auto send_int_values = reinterpret_cast(&ptr); + dtype_t recv_dtype; + auto recv_int_values = reinterpret_cast(&recv_dtype); + #pragma unroll + for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i) + recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast(src_lane_idx)); + return recv_dtype; +} + +CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + // Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100 + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast(&b.x))); + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast(&b.y))); +#else + const auto [x, y] = __bfloat1622float2(b); + a.x += x, a.y += y; +#endif +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8912a15766790db8a6fe8ba5a132df61a4958e39 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +CUTLASS_DEVICE void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +} // namespace deep_gemm::ptx diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5cd50c66f6da20a3c3be1d94cbe59757408c7f7b --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh @@ -0,0 +1,300 @@ +#pragma once + +#include +#include + +namespace deep_gemm::sched { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto candidate: {8u, 16u}) { + const auto usage = kIsMulticastOnA ? + candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for contiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = grouped_layout[group_idx]; + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + const uint32_t& shape_k, int* grouped_layout = nullptr) { + num_m_blocks = math::ceil_div(shape_m, BLOCK_M); + num_n_blocks = math::ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = grouped_layout[0]; + num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + // For swap A/B and psum layout only + CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const { + constexpr uint32_t UMMA_STEP_N = 16; + DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment"); + if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) + return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N); + return BLOCK_M; + } + + CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = math::ceil_div(static_cast(grouped_layout[current_group_idx]), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = math::align(current_psum_m, BLOCK_M); + current_psum_m = grouped_layout[current_group_idx]; + current_m_block_cumsum += num_m_blocks; + num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with block M + m_block_idx += last_psum_m / BLOCK_M; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto group_idx = grouped_layout[m_block_idx * BLOCK_M]; + const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M]; + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx]; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return m_offset + m_block_idx * BLOCK_M < current_psum_m; + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm::sched diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cdbecccd560398cc81747e685bddd2d4b3d0ebf0 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh @@ -0,0 +1,221 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::sched { + +// Computation phase for the current block +enum class BlockPhase { + None = 0, + Linear1 = 1, + Linear2 = 2 +}; + +template +struct MegaMoEScheduler { + DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config"); + + // NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster + // always land on the same m_block_idx with n_block_idx differing by 1 + DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster"); + + // Arrival counts + const layout::Workspace& workspace; + + // Scheduler state + BlockPhase next_phase = BlockPhase::Linear1; + + // Current expert and block indices + uint32_t current_local_expert_idx = 0; + uint32_t current_num_tokens = 0; + uint32_t current_pool_block_offset = 0; + uint32_t block_idx = 0; + uint32_t m_block_idx = 0; + uint32_t n_block_idx = 0; + + // Pre-cached per-expert token counts (filled during `for_each_block` init) + // Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count + uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {}; + + CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) { + block_idx = blockIdx.x; + } + + CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const { + return math::align(current_local_expert_idx + 1, kNumExpertsPerWave); + } + + CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const { + uint32_t valid_value; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ? + stored_num_tokens_per_expert[i] : valid_value; + } + return ptx::exchange(valid_value, expert_idx % 32); + } + + // Get pool block offset for a given expert index from a per-lane token count array + CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) { + uint32_t num_blocks = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M); + } + return __reduce_add_sync(0xffffffff, num_blocks); + } + + CUTLASS_DEVICE void advance_expert_idx() { + current_pool_block_offset += get_current_num_m_blocks(); + current_local_expert_idx += 1; + current_num_tokens = get_num_tokens(current_local_expert_idx); + } + + CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) { + current_local_expert_idx = expert_idx; + current_num_tokens = get_num_tokens(expert_idx); + current_pool_block_offset = get_pool_block_offset(expert_idx); + } + + CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { + return current_pool_block_offset; + } + + CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { + return math::ceil_div(current_num_tokens, BLOCK_M); + } + + template + CUTLASS_DEVICE uint32_t get_valid_m() const { + const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M); + return kDoUMMAAligned ? math::align(m, 16u) : m; + } + + CUTLASS_DEVICE bool fetch_next_l1_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + m_block_idx = block_idx / kNumL1BlockNs; + if (m_block_idx < num_m_blocks) + return true; + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL1BlockNs; + advance_expert_idx(); + } + return false; + } + + CUTLASS_DEVICE bool fetch_next_l2_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + if (block_idx < num_m_blocks * kNumL2BlockNs) { + m_block_idx = block_idx / kNumL2BlockNs; + return true; + } + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL2BlockNs; + advance_expert_idx(); + } + return false; + } + + // Core state machine: assigns the next block + CUTLASS_DEVICE cute::tuple get_next_block() { + while (true) { + if (current_local_expert_idx >= kNumExpertsPerRank) + break; + + if (next_phase == BlockPhase::Linear1) { + if (fetch_next_l1_block()) { + // Found a new L1 block + n_block_idx = block_idx - m_block_idx * kNumL1BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // L1 for the current wave is complete, transition to L2 + next_phase = BlockPhase::Linear2; + set_expert_idx(math::align(current_local_expert_idx - 1, kNumExpertsPerWave)); + } + } else { + if (fetch_next_l2_block()) { + // Found a new L2 block + n_block_idx = block_idx - m_block_idx * kNumL2BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // Move to L1 of the next wave + next_phase = BlockPhase::Linear1; + } + } + } + + // All waves and experts are fully processed + return {BlockPhase::None, 0, 0, 0}; + } + + CUTLASS_DEVICE void fetch_expert_recv_count() { + // NOTES: each lane caches experts at indices (i * 32 + lane_idx) + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + const auto expert_idx = i * 32 + ptx::get_lane_idx(); + uint64_t value = 0; + if (expert_idx < kNumExpertsPerRank) { + do { + value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); + } while (static_cast(value >> 32) != kNumSMs * kNumRanks); + } + stored_num_tokens_per_expert[i] = static_cast(value); + } + __syncwarp(); + } + + template + CUTLASS_DEVICE void for_each_block(Func&& func) { + // Wait for all expert counters to be finalized + fetch_expert_recv_count(); + + // Initialize current expert with 0 + set_expert_idx(0); + + // Iterate over all blocks + // TODO: add swizzle within expert waves for better L2 cache utilization + while (true) { + CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx); + if (block_phase == BlockPhase::None) + break; + + func(block_phase, current_local_expert_idx, + block_phase == BlockPhase::Linear2 ? kNumL2BlockKs : kNumL1BlockKs, + m_block_idx, n_block_idx); + } + } +}; + +} // namespace deep_gemm::sched diff --git a/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..548bbbc6ba59d8abb2c56698908ab0713c1f39cd --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::sched { + +template +CUTLASS_GLOBAL __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize]; + __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize]; + __shared__ uint32_t varlen_num_atoms_shared; + uint32_t num_items; + + if constexpr (kIsVarlen) { + if (lane_idx == 0) { + uint32_t t = 0, atom_count = 0; + while (t < batch_size) { + varlen_atom_token_start[atom_count] = t; + const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]); + varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t]; + t += is_paired ? 2 : 1; + ++ atom_count; + } + varlen_num_atoms_shared = atom_count; + } + __syncwarp(); + num_items = varlen_num_atoms_shared; + } else { + num_items = batch_size; + } + + // Compute num_segs and prefix sum + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + uint32_t context_len; + if constexpr (kIsVarlen) { + context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0); + } else { + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0); + } + num_segs[k] = math::ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + // SM work distribution + if constexpr (kIsVarlen) { + const uint32_t total = sum; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = num_items; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t atom_idx = lo; + const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]); + const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } else { + const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1; + const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom); + const uint32_t total = sum * num_next_n_atoms; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = batch_size; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t q_idx = lo; + const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms); + const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]); + const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0; + const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0; + const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx; + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } +} + +// Conditional storage for varlen indices pointer (EBO: zero cost when unused) +template +struct IndicesStorage { + const uint32_t* indices; +}; + +template <> +struct IndicesStorage {}; + +template +struct PagedMQALogitsScheduler : IndicesStorage { + const uint32_t* context_lens; + uint32_t batch_size; + + uint32_t current_q_atom_idx, current_kv_idx; + uint32_t end_q_atom_idx, end_kv_idx; + uint32_t current_num_kv; + + CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + if constexpr (kPadOddN) { + return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom; + } else { + return q_atom_idx * kNextNAtom; + } + } + } + + CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + return q_atom_idx / kNumNextNAtoms; + } + } + + CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + const bool is_paired = (q_atom_idx + 1 < batch_size and + this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]); + const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx]; + return math::ceil_div(ctx_len, BLOCK_KV); + } else { + const uint32_t q_idx = q_atom_idx / kNumNextNAtoms; + const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return math::ceil_div(context_lens[lens_idx], BLOCK_KV); + } + } + + CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size, + const uint32_t* context_lens, + const uint32_t* schedule_meta, const uint32_t* indices) { + this->context_lens = context_lens; + this->batch_size = batch_size; + if constexpr (kIsVarlen) { + this->indices = indices; + } + + const auto current_pack = reinterpret_cast(schedule_meta)[sm_idx]; + const auto end_pack = reinterpret_cast(schedule_meta)[sm_idx + 1]; + current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_atom_idx); + } + + // Advance step in q_atom_idx space when moving to the next atom. + // Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence. + // Non-varlen: always 1 (one atom unit). + CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const { + if constexpr (kIsVarlen) { + return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1; + } else { + return 1; + } + } + + // Whether num_kv should be refreshed after advancing to q_atom_idx. + // Varlen: always refresh (each atom may have a different context_len). + // Non-varlen: only at atom-group boundaries (atoms within a group share context_len). + CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + return true; + } else { + return q_atom_idx % kNumNextNAtoms == 0; + } + } + + CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_atom_idx = current_q_atom_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + current_kv_idx = 0; + current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx); + if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) { + current_num_kv = get_num_kv(current_q_atom_idx); + } + } + return true; + } + + CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const { + return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx); + } +}; + +} // namespace deep_gemm::sched diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp deleted file mode 100644 index 0f074309bb8023014d52c6b3f691450a161367dd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp +++ /dev/null @@ -1,904 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/arch/grid_dependency_control.h" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileSchedulerTag_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileSchedulerTag_, - cute::enable_if_t< - cutlass::detail::is_asymmetric_dma_kernel_tag_of_v || - cutlass::detail::is_asymmetric_dma_kernel_tag_of_v>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; - - using TileSchedulerTag = TileSchedulerTag_; - using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - using TileScheduler = typename detail::TileSchedulerSelector< - TileSchedulerTag, ArchTag, TileShape, ClusterShape - ,TileSchedulerPipelineStageCount - >::Scheduler; - - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - // Asymmetric buffering - // Tensor A/B could have different buffering, with number of KBLOCK, aka TILEK, - // and STAGEs. It let AsymmetricKRatio, equals KBLOCK_A / KBLOCK_B, to control - // the balance of A/B loading, make sure A/B's pipeline keep same cadence - // when produce / consume data. - // Currently, AsymmetricKRatio = {1, 2} is the only support. - static constexpr bool isAsymmetric = DispatchPolicy::Schedule::isAsymmetric; - static constexpr uint32_t AsymmetricKRatio = isAsymmetric ? 2 : 1; - - // Warp specialization thread count per threadblock - static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 8 warps - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp * 2; // 2 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C - - static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; - static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - - using TileSchedulerPipeline = typename TileScheduler::Pipeline; - using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState; - using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline; - using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState; - using TileSchedulerStorage = typename TileScheduler::SharedStorage; - - // Kernel level shared memory storage - struct SharedStorage { - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorageMK = typename CollectiveMainloop::PipelineStorageMK; - using MainloopPipelineStorageNK = typename CollectiveMainloop::PipelineStorageNK; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorageMK mainloop_mk; - alignas(16) MainloopPipelineStorageNK mainloop_nk; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - - alignas(16) TileSchedulerStorage scheduler; - - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - EpilogueTensorStorage epilogue; - MainloopTensorStorage mainloop; - } tensors; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm120_smem_capacity_bytes, "SMEM usage exceeded capacity."); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - void* workspace{nullptr}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - - void* scheduler_workspace = workspace_ptr + workspace_offset; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used - // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means - // subtile will not be used, therefore separate reduction will not be enabled. - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles - ); - - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), - hw_info, - scheduler, - workspace - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - static constexpr uint32_t NumAccumulatorMtxs = 1; - - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - - // Preconditions - static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); - static_assert(size<0>(TileShape{}) >= 128, - "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ - enum class WarpGroupRole { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole { - LoadMK = 0, - Warp1 = 1, - LoadNK = 2, - LoadMN = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int mma_thread_idx = thread_idx % NumMMAThreads; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); - // TileScheduler pipeline - typename TileSchedulerPipeline::Params scheduler_pipeline_params; - typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params; - if constexpr (IsSchedDynamicPersistent) { - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp1) { - scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::ProducerConsumer; - } - else { - scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; - } - scheduler_pipeline_params.producer_blockid = 0; - scheduler_pipeline_params.producer_arv_count = 1; - scheduler_pipeline_params.consumer_arv_count = NumSchedThreads + (NumMainloopLoadThreads + NumMMAThreads); - - if (is_epi_load_needed) { - scheduler_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; - } - scheduler_pipeline_params.transaction_bytes = sizeof(typename TileScheduler::CLCResponse); - - scheduler_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - scheduler_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - scheduler_throttle_pipeline_params.dst_blockid = 0; - scheduler_throttle_pipeline_params.initializing_warp = 1; - if (warp_group_role == WarpGroupRole::Producer && - producer_warp_role == ProducerWarpRole::Warp1) { - scheduler_throttle_pipeline_params.role = - TileSchedulerThrottlePipeline::ThreadCategory::Consumer; - } - // set role when it is for DMA warp in Mainloop - else if (warp_group_role == WarpGroupRole::Producer && - (producer_warp_role == ProducerWarpRole::LoadMK || - producer_warp_role == ProducerWarpRole::LoadNK)) { - scheduler_throttle_pipeline_params.role = - TileSchedulerThrottlePipeline::ThreadCategory::Producer; - } - } - TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params, ClusterShape{}); - TileSchedulerPipelineState scheduler_pipe_consumer_state; - - TileSchedulerThrottlePipeline scheduler_throttle_pipeline(shared_storage.scheduler.throttle_pipeline(), scheduler_throttle_pipeline_params); - TileSchedulerThrottlePipelineState scheduler_pipe_throttle_consumer_state; - TileSchedulerThrottlePipelineState scheduler_pipe_throttle_producer_state = cutlass::make_producer_start_state(); - - // Mainloop Load pipeline - using MainloopPipelineMK = typename CollectiveMainloop::MainloopPipelineMK; - using MainloopPipelineNK = typename CollectiveMainloop::MainloopPipelineNK; - typename MainloopPipelineMK::Params mainloop_pipeline_params_mk; - typename MainloopPipelineNK::Params mainloop_pipeline_params_nk; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadMK) { - mainloop_pipeline_params_mk.role = MainloopPipelineMK::ThreadCategory::Producer; - mainloop_pipeline_params_mk.is_leader = cute::elect_one_sync(); - mainloop_pipeline_params_mk.transaction_bytes = params.mainloop.tma_transaction_bytes_mk; - } - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadNK) { - mainloop_pipeline_params_nk.role = MainloopPipelineNK::ThreadCategory::Producer; - mainloop_pipeline_params_nk.is_leader = cute::elect_one_sync(); - mainloop_pipeline_params_nk.transaction_bytes = params.mainloop.tma_transaction_bytes_nk; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - mainloop_pipeline_params_mk.role = MainloopPipelineMK::ThreadCategory::Consumer; - mainloop_pipeline_params_nk.role = MainloopPipelineNK::ThreadCategory::Consumer; - } - mainloop_pipeline_params_mk.num_consumers = NumMMAThreads; - mainloop_pipeline_params_nk.num_consumers = NumMMAThreads; - - MainloopPipelineMK mainloop_pipeline_mk(shared_storage.pipelines.mainloop_mk, mainloop_pipeline_params_mk, ClusterShape{}); - MainloopPipelineNK mainloop_pipeline_nk(shared_storage.pipelines.mainloop_nk, mainloop_pipeline_params_nk, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadMN) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; - epi_load_pipeline_params.consumer_arv_count = NumMMAThreads; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - // 2 warps (LoadMK / LoadNK) are ordered before 1 warp (LoadMN) and will signal arrival. - params_load_order_barrier.group_id = ( - producer_warp_role == ProducerWarpRole::LoadMK || - producer_warp_role == ProducerWarpRole::LoadNK) ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp * 2; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineStateMK mainloop_pipe_consumer_state_mk; - typename CollectiveMainloop::PipelineStateNK mainloop_pipe_consumer_state_nk; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - typename CollectiveMainloop::PipelineStateMK mainloop_pipe_producer_state_mk = cutlass::make_producer_start_state(); - typename CollectiveMainloop::PipelineStateNK mainloop_pipe_producer_state_nk = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - TileScheduler scheduler{params.scheduler}; - if constexpr (IsSchedDynamicPersistent) { - scheduler.set_data_ptr(shared_storage.scheduler.data()); - } - // Declare work_tile_info, then define it in each of warps that use it. - typename TileScheduler::WorkTileInfo work_tile_info; - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) { - cutlass::arch::warpgroup_reg_dealloc(); - - // Scheduler Producer Warp - if (producer_warp_role == ProducerWarpRole::Warp1) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - if constexpr (IsSchedDynamicPersistent) { - bool requires_clc_query = true; - TileSchedulerPipelineState scheduler_pipe_producer_state = cutlass::make_producer_start_state(); - - cutlass::arch::wait_on_dependent_grids(); - - while (work_tile_info.is_valid()) { - if (requires_clc_query) { - // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. - scheduler_throttle_pipeline.consumer_wait(scheduler_pipe_throttle_consumer_state); - scheduler_throttle_pipeline.consumer_release(scheduler_pipe_throttle_consumer_state); - ++scheduler_pipe_throttle_consumer_state; - - // Query next clcID and update producer state - scheduler_pipe_producer_state = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_producer_state); - } - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - work_tile_info = next_work_tile_info; - } - scheduler_pipeline.producer_tail(scheduler_pipe_producer_state); - } - } // Scheduler Producer Warp End - else - // Producer Warp to LoadMK - if (producer_warp_role == ProducerWarpRole::LoadMK) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - bool do_load_order_arrive = true; - bool requires_clc_query = true; - while (work_tile_info.is_valid()) { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); - work_tile_info = next_work_tile_info; - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (requires_clc_query) { - scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state); - scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state); - ++scheduler_pipe_throttle_producer_state; - } - - collective_mainloop.load_MK( - params.mainloop, - mainloop_pipeline_mk, - mainloop_pipe_producer_state_mk, - load_inputs, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state_mk.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info - ,scheduler_pipeline - ,scheduler_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline_mk, mainloop_pipe_producer_state_mk); - - } // Producer Warp LoadMK End - - // LoadNK Producer Warp - if (producer_warp_role == ProducerWarpRole::LoadNK) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - - bool do_load_order_arrive = true; - bool requires_clc_query = true; - while (work_tile_info.is_valid()) { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); - work_tile_info = next_work_tile_info; - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape) * AsymmetricKRatio; - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info) * AsymmetricKRatio; - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (requires_clc_query) { - scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state); - scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state); - ++scheduler_pipe_throttle_producer_state; - } - - collective_mainloop.load_NK( - params.mainloop, - mainloop_pipeline_nk, - mainloop_pipe_producer_state_nk, - load_inputs, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state_nk.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info - ,scheduler_pipeline - ,scheduler_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline_nk, mainloop_pipe_producer_state_nk); - - } // Producer Warp LoadNK End - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::LoadMN && - is_epi_load_needed) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - - if (!TileScheduler::requires_separate_reduction(params.scheduler) && work_tile_info.is_valid()) { - load_order_barrier.wait(); - } - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - while (work_tile_info.is_valid()) { - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state = - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx() - ); - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info - ,scheduler_pipeline - ,scheduler_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Producer Warp LoadMN End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - cutlass::arch::warpgroup_reg_alloc(); - - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it - bool do_store_tail = false; - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - // Allocate the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - - collective_mainloop.mma( - mainloop_pipeline_mk, - mainloop_pipe_consumer_state_mk, - mainloop_pipeline_nk, - mainloop_pipe_consumer_state_nk, - accumulators, - k_tile_iter, - work_k_tile_count, - mma_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop, - blk_coord, - problem_shape_MNKL - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline_mk, - mainloop_pipe_consumer_state_mk, - mainloop_pipeline_nk, - mainloop_pipe_consumer_state_nk, - work_k_tile_count - ); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state_mk.advance(work_k_tile_count); - mainloop_pipe_consumer_state_nk.advance(work_k_tile_count * AsymmetricKRatio); - } - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); - - } - #endif - - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - mma_thread_idx, - shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx() - ); - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info - ,scheduler_pipeline - ,scheduler_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - if (do_store_tail) { - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state - ); - } - } // Consumer Warp Groups End - } - -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp deleted file mode 100644 index 18c79608ad6ba821f84ebc3ef717eddad7fd682c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ /dev/null @@ -1,270 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/tensor.hpp" - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, - cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - static constexpr bool IsGdcEnabled = false; - - static constexpr bool is_valid_tile_scheduler = - cute::is_void_v or cute::is_same_v; -static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler."); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); - - // MSVC requires the cast to fix a warning-as-error. - static constexpr int SharedStorageSize = static_cast(cute::max( - sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage))); - - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{})); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - - KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count}; - auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); - - return { - args.mode, - args.problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) - }; - } - - static bool - can_implement(Arguments const& args) { - bool mode_implementable = args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); - return mode_implementable && TileScheduler::can_implement(args.scheduler); - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - return workspace_size; - } - - static - cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - cutlass::Status status = Status::kSuccess; - - return status; - } - - static dim3 - get_grid_shape(Params const& params) { - int batch_count = 1; - if constexpr (cute::rank(ProblemShape{}) == 4) { - batch_count = cute::size<3>(params.problem_shape); - } - - return dim3( - cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), - cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), - batch_count - ); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - - // Preconditions - CUTE_STATIC_ASSERT(is_static::value); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto [M,N,K,L] = problem_shape_MNKL; - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - int thread_idx = int(threadIdx.x); - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); - auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l) - - // Represent the full tensors - Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) - Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) - - // Get batch slice - Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) - Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) - - // Slice to get the tiles this thread block is responsible for - Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - - // Compute tile residues for predication - auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord - auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord - auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); - - // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape - TiledMma tiled_mma; - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - clear(accumulators); - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - int k_tile_count = size<2>(gA); - - // Perform the collective scoped MMA - CollectiveMainloop collective_mma; - collective_mma( - accumulators, - gA, - gB, - accumulators, - k_tile_iter, k_tile_count, - residue_mnk, - thread_idx, - smem_buf - ); - // Epilogue and write to gD - CollectiveEpilogue epilogue{params.epilogue}; - epilogue( - problem_shape_MNKL, - blk_shape, - blk_coord_mnkl, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - smem_buf - ); - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp deleted file mode 100644 index c0ef53a7e86239cb10bb9c57dde255199fe8e3fc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp +++ /dev/null @@ -1,279 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/tensor.hpp" - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, - "ProblemShape{} should be or "); - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using InternalStrideA = typename CollectiveMainloop::InternalStrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using InternalStrideB = typename CollectiveMainloop::InternalStrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, - cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - static constexpr bool IsGdcEnabled = false; - - static constexpr bool is_valid_tile_scheduler = - cute::is_void_v or cute::is_same_v; -static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler."); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); - - // MSVC requires the cast to fix a warning-as-error. - static constexpr int SharedStorageSize = static_cast(cute::max( - sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage))); - - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{})); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - typename ProblemShape::UnderlyingProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); - - KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count}; - auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); - - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(problem_shape, args.epilogue, workspace) - }; - } - - static bool - can_implement(Arguments const& args) { - - bool implementable = (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - return workspace_size; - } - - static - cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - cutlass::Status status = Status::kSuccess; - - return status; - } - - static dim3 - get_grid_shape(Params const& params) { - int batch_count = cute::size<3>(params.problem_shape); - return dim3( - cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), - cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), - batch_count - ); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - - // Preconditions - CUTE_STATIC_ASSERT(is_static::value); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto [M,N,K,L] = problem_shape_MNKL; - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - int thread_idx = int(threadIdx.x); - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); - auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l) - - // Represent the full tensors - Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A[l_coord]), make_shape(M,K,1), params.mainloop.dA); //(m,k,l) - Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B[l_coord]), make_shape(N,K,1), params.mainloop.dB); //(n,k,l) - - // Get batch slice - Tensor mA_mk = mA_mkl(_,_,0); // (m,k) - Tensor mB_nk = mB_nkl(_,_,0); // (n,k) - - // Slice to get the tiles this thread block is responsible for - Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - - // Compute tile residues for predication - auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord - auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord - auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); - - // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape - TiledMma tiled_mma; - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - clear(accumulators); - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - int k_tile_count = size<2>(gA); - - - // Perform the collective scoped MMA - CollectiveMainloop collective_mma; - collective_mma( - accumulators, - gA, - gB, - accumulators, - k_tile_iter, k_tile_count, - residue_mnk, - thread_idx, - smem_buf - ); - - // Epilogue and write to gD - CollectiveEpilogue epilogue{params.epilogue}; - epilogue( - problem_shape_MNKL, - blk_shape, - blk_coord_mnkl, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - smem_buf - ); - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp deleted file mode 100644 index ec5cd4d0584a73825f5cb7dd909a7774463e1a2d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ /dev/null @@ -1,1039 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t> -> -{ - // Get the type of the scheduler response. - template - struct TileSchedulerResponseGetter { - using Type = typename TileScheduler::CLCResponse; - }; - - template - struct TileSchedulerResponseGetter> { - using Type = typename TileScheduler::SchedulerResponse; - }; - -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, - "ProblemShape{} should be or "); - - static_assert(cute::is_base_of_v); - - static constexpr bool IsGdcEnabled = false; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using InternalStrideA = typename CollectiveMainloop::InternalStrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using InternalStrideB = typename CollectiveMainloop::InternalStrideB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using Schedule = typename DispatchPolicy::Schedule; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; - static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; - - static_assert( - cute::is_void_v - or ( - IsGroupedGemmKernel - and cute::is_any_of_v - ), - "Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler."); - - using SchedulerTag = cute::conditional_t< - cute::is_void_v, - cute::conditional_t< - IsGroupedGemmKernel, - GroupScheduler, // Special grouped gemm scheduler - void // Default scheduler for non-grouped kernels - >, - TileScheduler_ - >; - - using TileScheduler = typename detail::TileSchedulerSelector< - SchedulerTag, - ArchTag, - TileShape, - ClusterShape, - 8, // SchedulerPipelineStageCount -- Grouped GEMM scheduler will benefit from a larger number of stages. - cute::conditional_t, void, ProblemShape> // Use void for default scheduler. - >::Scheduler; - - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - using TileSchedulerResponse = typename TileSchedulerResponseGetter::Type; - - static constexpr auto TileSchedulerStages = 8; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaThreads = size(TiledMma{}); - static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; - static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - - // Kernel level shared memory storage - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16, _1> { - using TileSchedulerPipelineStorage = typename TileScheduler::PipelineStorage; - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) TileSchedulerPipelineStorage scheduler; - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - - alignas(16) TileSchedulerResponse scheduler_response[TileSchedulerStages]; - - struct TensorMapStorage : cute::aligned_struct<128, _1> { - using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; - using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; - - alignas(128) MainloopTensorMapStorage mainloop; - alignas(128) EpilogueTensorMapStorage epilogue; - } tensormaps; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - void* workspace{nullptr}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - ProblemShape problem_shapes = args.problem_shape; - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - // Get maximum number of clusters that could co-exist on the target device - int max_active_clusters = args.hw_info.max_active_clusters; - if (max_active_clusters <= 0) { - max_active_clusters = 0; - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); - } - else { - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); - } - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - - void* mainloop_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - - void* scheduler_workspace = workspace_ptr + workspace_offset; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - - TileSchedulerParams scheduler; - if constexpr (IsGroupedGemmKernel) { - scheduler = TileScheduler::to_underlying_arguments( - problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace); - } - else { - scheduler = TileScheduler::to_underlying_arguments( - problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace); - } - - return { - args.mode, - problem_shapes, - CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), - hw_info, - scheduler, - workspace - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = true; - if constexpr (IsGroupedGemmKernel) { - // Group GEMM currently only supports rank-3 problem shapes - implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); - } - else { - implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); - } - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count); - workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); - - workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); - workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - static constexpr uint32_t NumAccumulatorMtxs = 1; - - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - dim3 grid_shape; - if constexpr (IsGroupedGemmKernel) { - grid_shape = TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - else { - grid_shape = TileScheduler::get_grid_shape(params.scheduler, params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args); - } - return grid_shape; - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ - CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -# endif - -// Any Tensor Op MMA Atom in the ISA is arch conditional. -#if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); - static_assert(size<0>(TileShape{}) >= 128, - "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - static_assert(NumMmaWarpGroups == 2, "Cooperative kernels currently only support NumMmaWarpGroups == 2"); - - if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v) { - static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups, - "Tiled MmA does not match expected warp groups performing the epilogue"); - } - - static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ - enum class WarpGroupRole { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole { - Mainloop = 0, - MainloopAux = 1, - Epilogue = 2, - Scheduler = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - auto scheduler = [&] () { - // Group scheduler requires a different constructor that takes a response ptr - if constexpr (cute::is_same_v) { - return TileScheduler{params.scheduler, shared_storage.scheduler_response}; - } - else { - return TileScheduler{params.scheduler}; - } - } (); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int mma_thread_idx = thread_idx % size(TiledMma{}); - auto warp_group_idx = canonical_warp_group_idx(); - auto warp_group_role = WarpGroupRole(warp_group_idx); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Note: Tma Descriptor Prefetch (from either const or param) is not applicable here - - // TileScheduler pipeline - using TileSchedulerPipeline = typename TileScheduler::Pipeline; - typename TileSchedulerPipeline::Params tile_scheduler_pipeline_params; - if constexpr (cute::is_same_v) { - if (warp_group_role == WarpGroupRole::Producer - && producer_warp_role == ProducerWarpRole::Scheduler) { - tile_scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Producer; - } - else { - tile_scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; - } - tile_scheduler_pipeline_params.consumer_arv_count = NumMmaThreads - + NumThreadsPerWarp * ( - 1 // Main DMA warp - + (collective_epilogue.is_producer_load_needed() ? 1 : 0) // Epilog DMA warp - + (IsMainloopAuxiliaryLoadNeeded ? 1 : 0) // Aux DMA warp - ); - tile_scheduler_pipeline_params.producer_arv_count = 1; - } - TileSchedulerPipeline tile_scheduler_pipeline(shared_storage.pipelines.scheduler, tile_scheduler_pipeline_params); - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer - && (producer_warp_role == ProducerWarpRole::Mainloop - || producer_warp_role == ProducerWarpRole::MainloopAux)) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumMmaThreads; - mainloop_pipeline_params.num_producers = NumProducerThreads; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename TileSchedulerPipeline::PipelineState tile_scheduler_pipe_consumer_state; - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState tile_scheduler_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - if (not work_tile_info.is_valid()) { - // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups - return; - } - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - - if (warp_group_role == WarpGroupRole::Producer) { - cutlass::arch::warpgroup_reg_dealloc(); - - if (producer_warp_role == ProducerWarpRole::Scheduler) { - // GroupScheduler requires a producer warp to iterate over the group infos and push - // the work tile infos to the downstream pipelines. - if constexpr (cute::is_same_v) { - do { - auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(tile_scheduler_pipeline, tile_scheduler_pipe_producer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_producer_state; - } - } while (work_tile_info.is_valid()); - tile_scheduler_pipeline.producer_tail(tile_scheduler_pipe_producer_state); - } - } - // Mainloop Producer Warp - else if (producer_warp_role == ProducerWarpRole::Mainloop) { - int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; - int32_t const mock_l_coord = 0; - int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); - int32_t const sm_count = params.hw_info.sm_count; - - // Fetch a copy of tensormaps for the CTA - auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx); - - // Update tensormap for the initial batch for the CTA - collective_mainloop.tensormaps_perform_update( - shared_storage.tensormaps.mainloop, - params.mainloop, - input_tensormaps, - problem_shape_MNKL, - curr_batch - ); - // Ensure warp is converged before issuing tensormap fence release - __syncwarp(); - // Entire warp must do this (i.e. it's aligned) - collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); - - bool do_load_order_arrive = true; - bool did_batch_change = true; - do { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (did_batch_change) { - load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch); - collective_mainloop.tensormaps_fence_acquire(input_tensormaps); - } - - collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - input_tensormaps, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Pipeline state is only advanced if there are K tiles to compute - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx - did_batch_change = next_batch != curr_batch; - if (work_tile_info.is_valid() && did_batch_change) { - curr_batch = next_batch; - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1); - } - collective_mainloop.tensormaps_perform_update( - shared_storage.tensormaps.mainloop, - params.mainloop, - input_tensormaps, - problem_shape_MNKL, - curr_batch - ); - // Ensure warp is converged before issuing tensor replace - __syncwarp(); - // Entire warp must do this (i.e. it's aligned) - collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); - } - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End - else if (producer_warp_role == ProducerWarpRole::MainloopAux) { - if constexpr (IsMainloopAuxiliaryLoadNeeded) { - int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; - int32_t const mock_l_coord = 0; - - bool did_batch_change = true; - do { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (did_batch_change) { - load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch); - } - - collective_mainloop.load_auxiliary( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx - did_batch_change = next_batch != curr_batch; - if (work_tile_info.is_valid() && did_batch_change) { - curr_batch = next_batch; - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1); - } - } - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - } // End of auxiliary load needed check - } // Mainloop Auxiliary Load Producer Warp End - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { - int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); - int32_t const sm_count = params.hw_info.sm_count; - - auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx)); - - bool did_batch_change = true; - constexpr bool IsEpiLoad = true; - - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_load_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - 0 - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); - - load_order_barrier.wait(); - - do { - int32_t curr_batch = work_tile_info.L_idx; - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - if (did_batch_change) { - collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); - } - - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue, - epi_load_tensormap, - work_tile_info.reduction_subtile_idx() - ); - } - - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - did_batch_change = curr_batch != work_tile_info.L_idx; - - if (work_tile_info.is_valid() && did_batch_change) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - // tensormap update - { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_load_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - 0 - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); - } - } - - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - cutlass::arch::warpgroup_reg_alloc(); - - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1; - - int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); - int32_t const sm_count = params.hw_info.sm_count; - // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it - bool do_store_tail = false; - // Get a copy of tensormaps - auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); - - bool did_batch_change = true; - constexpr bool IsEpiLoad = false; - - if (warp_idx_in_warp_group == 0) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - consumer_warp_group_idx - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, - epi_store_tensormap, - consumer_warp_group_idx); - } - - do { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - int32_t curr_batch = work_tile_info.L_idx; - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - - // Allocate the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - work_k_tile_count, - mma_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - work_k_tile_count - ); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); - } - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - - if (did_batch_change) { - collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); - } - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - mma_thread_idx, - shared_storage.tensors.epilogue, - epi_store_tensormap, - work_tile_info.reduction_subtile_idx() - ); - - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - - did_batch_change = curr_batch != work_tile_info.L_idx; - if (work_tile_info.is_valid() && did_batch_change) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - if (warp_idx_in_warp_group == 0) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - consumer_warp_group_idx - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, - epi_store_tensormap, - consumer_warp_group_idx); - } - } - - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - - // Cooperative only needs TMA to complete at the very end of the kernel - if (do_store_tail) { - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state - ); - } - } // Consumer Warp Groups End -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp deleted file mode 100644 index fd7ff603b8f17347767ee746d9cd29bd5ed81bf2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ /dev/null @@ -1,1110 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t> -> -{ - // Get the type of the scheduler response. - template - struct TileSchedulerResponseGetter { - using Type = typename TileScheduler::CLCResponse; - }; - - template - struct TileSchedulerResponseGetter> { - using Type = typename TileScheduler::SchedulerResponse; - }; - -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, - "ProblemShape{} should be or "); - - static_assert(cute::is_base_of_v); - - static constexpr bool IsGdcEnabled = false; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using InternalStrideA = typename CollectiveMainloop::InternalStrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using InternalStrideB = typename CollectiveMainloop::InternalStrideB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using Schedule = typename DispatchPolicy::Schedule; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; - static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; - - static_assert( - cute::is_void_v - or ( - IsGroupedGemmKernel - and cute::is_any_of_v - ), - "Ptr-Array Pingpong and Grouped Gemm Pingpong kernel only supports the default scheduler."); - - using SchedulerTag = cute::conditional_t< - cute::is_void_v, - cute::conditional_t< - IsGroupedGemmKernel, - GroupScheduler, // Special grouped gemm scheduler - void // Default scheduler for non-grouped kernels - >, - TileScheduler_ - >; - - using TileScheduler = typename detail::TileSchedulerSelector< - SchedulerTag, - ArchTag, - TileShape, - ClusterShape, - 8, // SchedulerPipelineStageCount -- Grouped GEMM scheduler will benefit from a larger number of stages. - cute::conditional_t, void, ProblemShape> // Use void for default scheduler. - >::Scheduler; - - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - using TileSchedulerResponse = typename TileSchedulerResponseGetter::Type; - - static constexpr auto TileSchedulerStages = 8; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; - static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - - // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue - static constexpr uint32_t StagesPerMathWarpGroup = 2; - using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; - using MathWarpGroupOrderBarrierSharedStorage = cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage< - MathWarpGroupOrderBarrier::SequenceDepth, - MathWarpGroupOrderBarrier::SequenceLength>; - - // Kernel level shared memory storage - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16, _1> { - using TileSchedulerPipelineStorage = typename TileScheduler::PipelineStorage; - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage; - - alignas(16) TileSchedulerPipelineStorage scheduler; - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; - } pipelines; - - alignas(16) TileSchedulerResponse scheduler_response[TileSchedulerStages]; - - struct TensorMapStorage : cute::aligned_struct<128, _1> { - using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; - using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; - - alignas(128) MainloopTensorMapStorage mainloop; - alignas(128) EpilogueTensorMapStorage epilogue; - } tensormaps; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - void* workspace{nullptr}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - ProblemShape problem_shapes = args.problem_shape; - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - // Get maximum number of clusters that could co-exist on the target device - int max_active_clusters = args.hw_info.max_active_clusters; - if (max_active_clusters <= 0) { - max_active_clusters = 0; - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); - } - else { - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); - } - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - - void* mainloop_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - - void* scheduler_workspace = workspace_ptr + workspace_offset; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - - // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used - // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means - // subtile will not be used, therefore separate reduction will not be enabled. - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler; - if constexpr (IsGroupedGemmKernel) { - scheduler = TileScheduler::to_underlying_arguments( - problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); - } - else { - scheduler = TileScheduler::to_underlying_arguments( - problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); - } - - return { - args.mode, - problem_shapes, - CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), - hw_info, - scheduler, - workspace - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = true; - if constexpr (IsGroupedGemmKernel) { - // Group GEMM currently only supports rank-3 problem shapes - implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); - } - else { - implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); - } - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count); - workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); - - workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); - workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - static constexpr uint32_t NumAccumulatorMtxs = 1; - - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - dim3 grid_shape; - if constexpr (IsGroupedGemmKernel) { - grid_shape = TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - else { - grid_shape = TileScheduler::get_grid_shape(params.scheduler, params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args); - } - return grid_shape; - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ - CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -# endif - -// Any Tensor Op MMA Atom in the ISA is arch conditional. -#if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(size(TiledMma{}) == 128, "Pingpong kernel must have TiledMMA operating using 128 threads."); - static_assert(NumMmaWarpGroups == 2, "Pingpong kernels currently only support NumMmaWarpGroups == 2"); - - if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v) { - static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups, - "Tiled MmA does not match expected warp groups performing the epilogue"); - } - - static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - enum class WarpGroupRole { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole { - Mainloop = 0, - MainloopAux = 1, - Epilogue = 2, - Scheduler = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - auto scheduler = [&] () { - // Group scheduler requires a different constructor that takes a response ptr - if constexpr (cute::is_same_v) { - return TileScheduler{params.scheduler, shared_storage.scheduler_response}; - } - else { - return TileScheduler{params.scheduler}; - } - } (); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int mma_thread_idx = thread_idx % size(TiledMma{}); - auto warp_group_idx = canonical_warp_group_idx(); - auto warp_group_role = WarpGroupRole(warp_group_idx); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Note: Tma Descriptor Prefetch (from either const or param) is not applicable here - - // TileScheduler pipeline - using TileSchedulerPipeline = typename TileScheduler::Pipeline; - typename TileSchedulerPipeline::Params tile_scheduler_pipeline_params; - if constexpr (cute::is_same_v) { - if (warp_group_role == WarpGroupRole::Producer - && producer_warp_role == ProducerWarpRole::Scheduler) { - tile_scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Producer; - } - else { - tile_scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; - } - tile_scheduler_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup * NumMmaWarpGroups // 1 MATH WG - + NumThreadsPerWarp * ( - 1 // Main DMA warp - + (collective_epilogue.is_producer_load_needed() ? 1 : 0) // Epilog DMA warp - + (IsMainloopAuxiliaryLoadNeeded ? 1 : 0) // Aux DMA warp - ); - tile_scheduler_pipeline_params.producer_arv_count = 1; - } - TileSchedulerPipeline tile_scheduler_pipeline(shared_storage.pipelines.scheduler, tile_scheduler_pipeline_params); - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer - && (producer_warp_role == ProducerWarpRole::Mainloop - || producer_warp_role == ProducerWarpRole::MainloopAux)) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.num_producers = NumProducerThreads; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; - // DMA Load WG will not participate in these Ordered Barrier syncs - params_math_wg_order_barrier.group_id = warp_group_idx - static_cast(WarpGroupRole::Consumer0); - params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group - MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename TileSchedulerPipeline::PipelineState tile_scheduler_pipe_consumer_state; - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState tile_scheduler_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - if (not work_tile_info.is_valid()) { - // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups - return; - } - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - - // Consumer1 is not on the critical path at prologue. - if (warp_group_role == WarpGroupRole::Consumer1) [[unlikely]] { - // Advance 2nd Math WG to the next work tile for the startup - const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (!work_tile_info.is_valid()) { - return; - } - - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - - // Advance 2nd Math WG pipeline states to the end of 1st Math WG - mainloop_pipe_consumer_state.advance(k_tile_count); - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - - if (warp_group_role == WarpGroupRole::Producer) { - cutlass::arch::warpgroup_reg_dealloc(); - - if (producer_warp_role == ProducerWarpRole::Scheduler) { - // GroupScheduler requires a producer warp to iterate over the group infos and push - // the work tile infos to the downstream pipelines. - if constexpr (cute::is_same_v) { - do { - auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(tile_scheduler_pipeline, tile_scheduler_pipe_producer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_producer_state; - } - } while (work_tile_info.is_valid()); - tile_scheduler_pipeline.producer_tail(tile_scheduler_pipe_producer_state); - } - } - // Mainloop Producer Warp - else if (producer_warp_role == ProducerWarpRole::Mainloop) { - int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; - int32_t const mock_l_coord = 0; - int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); - int32_t const sm_count = params.hw_info.sm_count; - - // Fetch a copy of tensormaps for the CTA - auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx); - - // Update tensormap for the initial batch for the CTA - collective_mainloop.tensormaps_perform_update( - shared_storage.tensormaps.mainloop, - params.mainloop, - input_tensormaps, - problem_shape_MNKL, - curr_batch - ); - // Ensure warp is converged before issuing tensormap fence release - __syncwarp(); - // Entire warp must do this (i.e. it's aligned) - collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); - - bool do_load_order_arrive = true; - bool did_batch_change = true; - do { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (did_batch_change) { - load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch); - collective_mainloop.tensormaps_fence_acquire(input_tensormaps); - } - - collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - input_tensormaps, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Pipeline state is only advanced if there are K tiles to compute - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx - did_batch_change = next_batch != curr_batch; - if (work_tile_info.is_valid() && did_batch_change) { - curr_batch = next_batch; - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1); - } - collective_mainloop.tensormaps_perform_update( - shared_storage.tensormaps.mainloop, - params.mainloop, - input_tensormaps, - problem_shape_MNKL, - curr_batch - ); - // Ensure warp is converged before issuing tensor replace - __syncwarp(); - // Entire warp must do this (i.e. it's aligned) - collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); - } - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End - else if (producer_warp_role == ProducerWarpRole::MainloopAux) { - if constexpr (IsMainloopAuxiliaryLoadNeeded) { - int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; - int32_t const mock_l_coord = 0; - - bool did_batch_change = true; - do { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (did_batch_change) { - load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch); - } - - collective_mainloop.load_auxiliary( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx - did_batch_change = next_batch != curr_batch; - if (work_tile_info.is_valid() && did_batch_change) { - curr_batch = next_batch; - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1); - } - } - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - } // End of auxiliary load needed check - } // Mainloop Auxiliary Load Producer Warp End - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { - int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); - int32_t const sm_count = params.hw_info.sm_count; - - auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx)); - - bool did_batch_change = true; - constexpr bool IsEpiLoad = true; - - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_load_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - 0 - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); - - load_order_barrier.wait(); - - do { - int32_t curr_batch = work_tile_info.L_idx; - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - if (did_batch_change) { - collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); - } - - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue, - epi_load_tensormap, - work_tile_info.reduction_subtile_idx() - ); - } - - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - did_batch_change = curr_batch != work_tile_info.L_idx; - - if (work_tile_info.is_valid() && did_batch_change) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - // tensormap update - { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_load_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - 0 - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); - } - } - - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - cutlass::arch::warpgroup_reg_alloc(); - - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1; - - int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); - int32_t const sm_count = params.hw_info.sm_count; - // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it - bool do_store_tail = false; - // Get a copy of tensormaps - auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); - - bool did_batch_change = true; - constexpr bool IsEpiLoad = false; - - if (warp_idx_in_warp_group == 0) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - consumer_warp_group_idx - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, - epi_store_tensormap, - consumer_warp_group_idx); - } - - do { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - - int32_t curr_batch = work_tile_info.L_idx; - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - - // Allocate the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - - math_wg_order_barrier.wait(); - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - work_k_tile_count, - mma_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - math_wg_order_barrier.arrive(); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - work_k_tile_count - ); - - math_wg_order_barrier.wait(); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); - } - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - - if (did_batch_change) { - collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); - } - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - mma_thread_idx, - shared_storage.tensors.epilogue, - epi_store_tensormap, - work_tile_info.reduction_subtile_idx() - ); - - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - - // Skip a tile for pingpong - if (work_tile_info.is_valid()) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - mainloop_pipe_consumer_state.advance(work_k_tile_count); - - // Go to next tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++tile_scheduler_pipe_consumer_state; - } - } - - did_batch_change = curr_batch != work_tile_info.L_idx; - if (work_tile_info.is_valid() && did_batch_change) { - if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); - } - if (warp_idx_in_warp_group == 0) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - problem_shape_MNKL, - work_tile_info.L_idx, - consumer_warp_group_idx - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, - epi_store_tensormap, - consumer_warp_group_idx); - } - } - - // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels - // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives - // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. - auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state - ); - - // Update starting load/store pipeline states for the next tile - // state has already been incremented by 1 tile in collective calls, advance once again for ping pong - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - - // Cue for next Math WG's Epilogue to start - math_wg_order_barrier.arrive(); - - } while (work_tile_info.is_valid()); // Scheduler work fetch loop - } // Consumer Warp Groups End -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp deleted file mode 100644 index 2292d7e4a2d0f0355e62fb338023beaba370d0cf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ /dev/null @@ -1,306 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/trace.h" -#include "cute/tensor.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - static constexpr bool IsGdcEnabled = false; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); - - static_assert(cute::is_void_v or cute::is_same_v, - "TMA kernel does not support specializing the tile scheduler."); - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - - static constexpr int SharedStorageSize = static_cast(cute::max( - sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage))); - - static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::ThreadCount; - - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - auto cluster_shape = ClusterShape{}; - auto tile_shape = TileShape{}; - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - return TileScheduler::get_tiled_cta_shape_mnl( - problem_shape_MNKL, tile_shape, cluster_shape); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - int thread_idx = int(threadIdx.x); - int warp_idx = canonical_warp_idx_sync(); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - } - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) - - // Get the appropriate blocks for this thread block -- potential for thread block locality - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice - - // Make tiled views - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - // Compute m_coord, n_coord, and l_coord with their post-tiled shapes - auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); - auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); - auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Slice with m_coord and n_coord - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape - TiledMma tiled_mma; - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - auto k_tile_count = size<2>(gA); - - // Perform the collective scoped MMA - CollectiveMainloop collective_mma; - collective_mma( - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - accumulators, - k_tile_iter, k_tile_count, - thread_idx, - block_rank_in_cluster, - smem_buf, - params.mainloop - ); - - constexpr int BLK_M_RANK = cute::rank<0>(blk_shape); - auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return get(M) - get<0,i>(blk_shape) * get(m_coord); - })); - - constexpr int BLK_N_RANK = cute::rank<1>(blk_shape); - auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return get(N) - get<1,i>(blk_shape) * get(n_coord); - })); - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); - - // Epilogue and write to gD - CollectiveEpilogue epilogue{params.epilogue}; - epilogue( - problem_shape_MNKL, - blk_shape, - output_tile_coord, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - smem_buf - ); -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp deleted file mode 100644 index 5b558005f315e4b1a8143b67096931ed16ad490c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ /dev/null @@ -1,522 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" - -#include "cutlass/conv/detail.hpp" - -#include "cute/tensor.hpp" -#include "cute/arch/cluster_sm90.hpp" - -#include "cutlass/arch/grid_dependency_control.h" - - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t> -> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - - // Handles the static_assert placed inside the operator() - // This is also used to decide whether the load_init inside collective mainloop returns rank 4 tensors or rank 5 tensors - static constexpr bool IsConvProblemShape = not (cute::is_tuple_v|| IsCutlass3ArrayKernel::value); - static_assert( IsConvProblemShape || (cute::rank(ProblemShape{}) == 3 || cute::rank(ProblemShape{}) == 4), "ProblemShape{} should be or for Gemm"); - - static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(cute::is_void_v or cute::is_same_v, - "TMA warp-specialized kernel does not support specializing the tile scheduler."); - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileSchedulerTag, ArchTag, TileShape, ClusterShape>::Scheduler; - - using TileSchedulerArguments = typename TileScheduler::Arguments; - - // Kernel level shared memory storage - struct SharedStorage { - // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union - union TensorStorage { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 1; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Device side arguments - struct Arguments { - cutlass::gemm::GemmUniversalMode mode{}; //maintained here for backward compatibility - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - - // Default constructor - Arguments() = default; - - // Constructor with specified mode - // It is used for Gemm - Arguments( - cutlass::gemm::GemmUniversalMode mode_, - ProblemShape problem_shape_, - MainloopArguments mainloop_, - EpilogueArguments epilogue_, - KernelHardwareInfo hw_info_ = KernelHardwareInfo(), - TileSchedulerArguments scheduler_ = TileSchedulerArguments()) - : mode(mode_) - , problem_shape(problem_shape_) - , mainloop(mainloop_) - , epilogue(epilogue_) - , hw_info(hw_info_) - , scheduler(scheduler_) {} - - // Constructor with default value for 'mode' - // This allows us to set GemmUniversal mode as kGemm for Conv right away - // while keeping the testbeds unchanged - Arguments( - ProblemShape problem_shape_, - MainloopArguments mainloop_, - EpilogueArguments epilogue_, - KernelHardwareInfo hw_info_ = KernelHardwareInfo(), - TileSchedulerArguments scheduler_ = TileSchedulerArguments()) - : mode(cutlass::gemm::GemmUniversalMode::kGemm) // Default mode - , problem_shape(problem_shape_) - , mainloop(mainloop_) - , epilogue(epilogue_) - , hw_info(hw_info_) - , scheduler(scheduler_) {} - - }; - - // Kernel entry point API - struct Params { - using ProblemShapeMNKL = decltype(cutlass::conv::detail::get_problem_shape_MNKL_helper(ProblemShape{}, cute::conditional_t{})); - ProblemShapeMNKL problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static Params - to_underlying_arguments(Arguments const& args, void* workspace) { - - (void) workspace; - auto problem_shape_mnkl = cutlass::conv::detail::get_problem_shape_MNKL_helper(args.problem_shape, cute::conditional_t{}); - auto transformed_problem_shape = cutlass::conv::detail::get_transformed_problem_shape_MNKL(args.problem_shape); - - auto swapped_problem_shape = problem_shape_mnkl; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(swapped_problem_shape) = get<1>(problem_shape_mnkl); - get<1>(swapped_problem_shape) = get<0>(problem_shape_mnkl); - } - return { - swapped_problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(transformed_problem_shape, args.epilogue, workspace) - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = true; - auto transformed_problem_shape = cutlass::conv::detail::get_transformed_problem_shape_MNKL(args.problem_shape); - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(transformed_problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - auto cluster_shape = ClusterShape{}; - auto tile_shape = TileShape{}; - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - return TileScheduler::get_tiled_cta_shape_mnl( - problem_shape_MNKL, tile_shape, cluster_shape); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ - CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -# endif - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - enum class WarpGroupRole { - Producer = 0, - Consumer = 1, - }; - enum class ProducerWarpRole { - MainloopEpilogue = 0, - Warp1 = 1, - Warp2 = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [&] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Preconditions only valid for Gemm - static_assert(IsConvProblemShape || cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(IsConvProblemShape || cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(IsConvProblemShape || cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(IsConvProblemShape || cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - TiledMma tiled_mma; - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - // Using constexpr if (C++17 and later) - auto problem_shape_MNKL = append<4>(params.problem_shape, cute::Int<1>{}); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Prepare and partition the input tensors. - // Expects a tuple of tensors for conv where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - // Compute m_coord, n_coord, and l_coord with their post-tiled shapes - auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); - // handles the difference between the rank of Tensor returned by load_input in case they do not have a batch mode - auto l_coord = [&] (auto const& gB_nkl_) { - // gB_nkl needs to be passed into the lambda because C++17 - // does not permit lambda capture of structured bindings. - if constexpr (not IsConvProblemShape) { - // This needs to be inside an `if constexpr`, - // because shape<4>(gB_nkl) is not well-formed otherwise. - return idx2crd(int(blockIdx.z), shape<4>(gB_nkl_)); - } - else { - return Int<0>{}; - } - } (gB_nkl); - - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get pipeline iterators and increments from tensor shapes - auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - auto k_tile_count = size<3>(gA_mkl); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) { - if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting mainloop pipeline state for the pipeline drain - mainloop_pipe_producer_state.advance(k_tile_count); - // Make sure mainloop consumer has been waited upon before issuing epilogue load - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_producer_load_needed()) { - // Ensure warp is converged before issuing epilogue loads - __syncwarp(); - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue - ); - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } - } - } - else if (warp_group_role == WarpGroupRole::Consumer) { - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - k_tile_count, - warp_group_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - k_tile_count - ); - - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state_next, - epi_store_pipeline, - epi_store_pipe_producer_state_next - ); - } -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp deleted file mode 100644 index d398d1f2906c473453f774e528adde246e953620..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ /dev/null @@ -1,861 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/arch/grid_dependency_control.h" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileSchedulerTag_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileSchedulerTag_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; - using TileSchedulerTag = TileSchedulerTag_; - - using TileScheduler = typename detail::TileSchedulerSelector< - TileSchedulerTag, - ArchTag, - TileShape, - ClusterShape - ,TileSchedulerPipelineStageCount - >::Scheduler; - - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - // Warp specialization thread count per threadblock - static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 8 warps - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C - - static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; - static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups; - static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; - static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; - - /// Register requirement for Load and Math WGs - static constexpr int RegsPerThread = - size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads * - sizeof(ElementAccumulator) / sizeof(uint32_t); - static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; - static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; - static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - - using TileSchedulerPipeline = typename TileScheduler::Pipeline; - using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState; - using TileSchedulerStorage = typename TileScheduler::SharedStorage; - using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline; - using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState; - - // Kernel level shared memory storage - struct SharedStorage { - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - - alignas(16) TileSchedulerStorage scheduler; - - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - EpilogueTensorStorage epilogue; - MainloopTensorStorage mainloop; - } tensors; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - void* workspace{nullptr}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - // Get maximum number of clusters that could co-exist on the target device - int max_active_clusters = args.hw_info.max_active_clusters; - if (max_active_clusters <= 0) { - max_active_clusters = 0; - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); - } - else { - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); - } - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* scheduler_workspace = workspace_ptr + workspace_offset; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used - // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means - // subtile will not be used, therefore separate reduction will not be enabled. - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles - ); - - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), - hw_info, - scheduler, - workspace - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - return workspace_size; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - static constexpr uint32_t NumAccumulatorMtxs = 1; - - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ - CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -# endif - -// Any Tensor Op MMA Atom in the ISA is arch conditional. -#if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(NumMMAThreads == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); - static_assert(size<0>(TileShape{}) >= 128, - "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ - enum class WarpGroupRole { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole { - Mainloop = 0, - Warp1 = 1, - Epilogue = 2, - MainloopAux = 3 - }; - - - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int mma_thread_idx = thread_idx % NumMMAThreads; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); - // TileScheduler pipeline - typename TileSchedulerPipeline::Params scheduler_pipeline_params; - typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params; - if constexpr (IsSchedDynamicPersistent) { - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp1) { - scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::ProducerConsumer; - } - else { - scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; - } - scheduler_pipeline_params.producer_blockid = 0; - scheduler_pipeline_params.producer_arv_count = 1; - scheduler_pipeline_params.consumer_arv_count = NumSchedThreads + NumMainloopLoadThreads + NumMMAThreads; - - if (is_epi_load_needed) { - scheduler_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; - } - scheduler_pipeline_params.transaction_bytes = sizeof(typename TileScheduler::CLCResponse); - - scheduler_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - scheduler_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - scheduler_throttle_pipeline_params.dst_blockid = 0; - scheduler_throttle_pipeline_params.initializing_warp = 3; - if (warp_group_role == WarpGroupRole::Producer && - producer_warp_role == ProducerWarpRole::Warp1) { - scheduler_throttle_pipeline_params.role = - TileSchedulerThrottlePipeline::ThreadCategory::Consumer; - } - // set role when it is for DMA warp in Mainloop - else if (warp_group_role == WarpGroupRole::Producer && - producer_warp_role == ProducerWarpRole::Mainloop) { - scheduler_throttle_pipeline_params.role = - TileSchedulerThrottlePipeline::ThreadCategory::Producer; - } - } - TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params); - TileSchedulerPipelineState scheduler_pipe_consumer_state; - - TileSchedulerThrottlePipeline scheduler_throttle_pipeline(shared_storage.scheduler.throttle_pipeline(), scheduler_throttle_pipeline_params); - TileSchedulerThrottlePipelineState scheduler_pipe_throttle_consumer_state; - TileSchedulerThrottlePipelineState scheduler_pipe_throttle_producer_state = cutlass::make_producer_start_state(); - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && (producer_warp_role == ProducerWarpRole::Mainloop || - producer_warp_role == ProducerWarpRole::MainloopAux)) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumMMAThreads; - mainloop_pipeline_params.num_producers = NumProducerThreads; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; - epi_load_pipeline_params.consumer_arv_count = NumMMAThreads; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - - auto cluster_wait_fn = [] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - TileScheduler scheduler{params.scheduler}; - if constexpr (IsSchedDynamicPersistent) { - scheduler.set_data_ptr(shared_storage.scheduler.data()); - } - // Declare work_tile_info, then define it in each of warps that use it. - typename TileScheduler::WorkTileInfo work_tile_info; - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - cutlass::arch::warpgroup_reg_dealloc(); - - // Scheduler Producer Warp - if (producer_warp_role == ProducerWarpRole::Warp1) { - if constexpr (IsSchedDynamicPersistent) { - bool requires_clc_query = true; - TileSchedulerPipelineState scheduler_pipe_producer_state = cutlass::make_producer_start_state(); - - cutlass::arch::wait_on_dependent_grids(); - while (work_tile_info.is_valid()) { - - if (requires_clc_query) { - // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. - scheduler_throttle_pipeline.consumer_wait(scheduler_pipe_throttle_consumer_state); - scheduler_throttle_pipeline.consumer_release(scheduler_pipe_throttle_consumer_state); - ++scheduler_pipe_throttle_consumer_state; - - // Query next work tile - scheduler_pipe_producer_state = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_producer_state); - } - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - - work_tile_info = next_work_tile_info; - } - scheduler_pipeline.producer_tail(scheduler_pipe_producer_state); - } - } // Scheduler Producer Warp End - else - - // Mainloop Producer Warp - if (producer_warp_role == ProducerWarpRole::Mainloop) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - bool do_load_order_arrive = true; - bool requires_clc_query = true; - while (work_tile_info.is_valid()) { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); - work_tile_info = next_work_tile_info; - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - if (requires_clc_query) { - scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state); - scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state); - ++scheduler_pipe_throttle_producer_state; - } - - collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - } - else if (producer_warp_role == ProducerWarpRole::MainloopAux) { - if constexpr (IsMainloopAuxiliaryLoadNeeded) { - while (work_tile_info.is_valid()) { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); - work_tile_info = next_work_tile_info; - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - collective_mainloop.load_auxiliary( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, work_k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - - work_tile_info = next_work_tile_info; - } // Scheduler work fetch loop - - } - } - - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && is_epi_load_needed) { - - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - - if (!TileScheduler::requires_separate_reduction(params.scheduler) && work_tile_info.is_valid()) { - load_order_barrier.wait(); - } - - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - while (work_tile_info.is_valid()) { - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state = - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx() - ); - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - cutlass::arch::warpgroup_reg_alloc(); - - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it - bool do_store_tail = false; - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - // Allocate the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - work_k_tile_count, - mma_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - work_k_tile_count - ); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); - } - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); - - } - #endif - - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - mma_thread_idx, - shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx() - ); - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - if constexpr (IsSchedDynamicPersistent) { - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - } // Scheduler work fetch loop - - if (do_store_tail) { - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state - ); - } - } // Consumer Warp Groups End -#endif - } - -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp deleted file mode 100644 index 1326f390fdcd536cec9f74bd8c311342ef2d53de..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ /dev/null @@ -1,946 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/fast_math.h" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" - -#include "cute/tensor.hpp" -#include "cutlass/arch/grid_dependency_control.h" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(!cute::is_same_v, "Ping-pong kernel does not currently support stream-K scheduler."); - static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileSchedulerTag, - ArchTag, - TileShape, - ClusterShape, - TileSchedulerPipelineStageCount - >::Scheduler; - - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - using TileSchedulerPipeline = typename TileScheduler::Pipeline; - using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState; - using TileSchedulerStorage = typename TileScheduler::SharedStorage; - - using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline; - using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState; - - static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; - - // Warp specialization thread count per threadblock - static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; - static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 4 warp - static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads * NumMmaWarpGroups + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; - - static_assert(NumMMAThreads == 128, "Pingpong kernel must have TiledMMA operating using 128 threads."); - static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total."); - - /// Register requirement for Load and Math WGs - static constexpr int RegsPerThread = - (size<0>(TileShape{}) * size<1>(TileShape{}) * sizeof(ElementAccumulator)) - / (NumMMAThreads * sizeof(uint32_t)); - static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; - static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; - static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - - // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue - static constexpr uint32_t StagesPerMathWarpGroup = 2; - using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< - StagesPerMathWarpGroup, NumMmaWarpGroups>; - using MathWarpGroupOrderBarrierSharedStorage = - cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage< - MathWarpGroupOrderBarrier::SequenceDepth, - MathWarpGroupOrderBarrier::SequenceLength>; - - // Kernel level shared memory storage - struct SharedStorage { - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - - alignas(16) TileSchedulerStorage scheduler; - - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - EpilogueTensorStorage epilogue; - MainloopTensorStorage mainloop; - } tensors; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - (void) workspace; - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - // Get maximum number of clusters that could co-exist on the target device - int max_active_clusters = args.hw_info.max_active_clusters; - if (max_active_clusters <= 0) { - max_active_clusters = 0; - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); - } - else { - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); - } - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* scheduler_workspace = workspace_ptr + workspace_offset; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), - hw_info, - TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles - ) - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_size = 0; - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - static constexpr uint32_t NumEpilogueSubTiles = 1; - static constexpr uint32_t NumAccumulatorMtxs = 1; - - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ - CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -# endif - -// Any Tensor Op MMA Atom in the ISA is arch conditional. -#if ! defined(ENABLE_SM90_KERNEL_LEVEL) - printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - enum class WarpGroupRole { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole { - Mainloop = 0, - Warp1 = 1, - Epilogue = 2, - MainloopAux = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - - // TileScheduler pipeline - typename TileSchedulerPipeline::Params scheduler_pipeline_params; - typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params; - if constexpr (IsSchedDynamicPersistent) { - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp1) { - scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::ProducerConsumer; - } - else { - scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; - } - scheduler_pipeline_params.producer_blockid = 0; - scheduler_pipeline_params.producer_arv_count = 1; - scheduler_pipeline_params.consumer_arv_count = NumSchedThreads + NumMainloopLoadThreads + NumMMAThreads; - - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); - - if (is_epi_load_needed) { - scheduler_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; - } - scheduler_pipeline_params.transaction_bytes = sizeof(typename TileScheduler::CLCResponse); - - scheduler_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - scheduler_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - scheduler_throttle_pipeline_params.dst_blockid = 0; - if (warp_group_role == WarpGroupRole::Producer && - producer_warp_role == ProducerWarpRole::Warp1) { - scheduler_throttle_pipeline_params.role = - TileSchedulerThrottlePipeline::ThreadCategory::Consumer; - } - // set role when it is for DMA warp in Mainloop - else if (warp_group_role == WarpGroupRole::Producer && - producer_warp_role == ProducerWarpRole::Mainloop) { - scheduler_throttle_pipeline_params.role = - TileSchedulerThrottlePipeline::ThreadCategory::Producer; - } - } - TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params); - TileSchedulerPipelineState scheduler_pipe_consumer_state; - - TileSchedulerThrottlePipeline scheduler_throttle_pipeline(shared_storage.scheduler.throttle_pipeline(), scheduler_throttle_pipeline_params); - TileSchedulerThrottlePipelineState scheduler_pipe_throttle_consumer_state; - TileSchedulerThrottlePipelineState scheduler_pipe_throttle_producer_state = cutlass::make_producer_start_state(); - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && (producer_warp_role == ProducerWarpRole::Mainloop - || producer_warp_role == ProducerWarpRole::MainloopAux)) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.num_producers = NumProducerThreads; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; - // DMA Load WG will not participate in these Ordered Barrier syncs - params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); - params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group - MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [&] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - TileScheduler scheduler{params.scheduler}; - if constexpr (IsSchedDynamicPersistent) { - scheduler.set_data_ptr(shared_storage.scheduler.data()); - } - - if (warp_group_role == WarpGroupRole::Consumer1) { - - if constexpr (not IsSchedDynamicPersistent) { - // Advance 2nd Math WG to the next work tile for the startup - scheduler.advance_to_next_work(); - } - - // Advance 2nd Math WG pipeline states to the end of 1st Math WG - mainloop_pipe_consumer_state.advance(k_tile_count); - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - } - auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) { - cutlass::arch::warpgroup_reg_dealloc(); - - // Scheduler Producer Warp - if (producer_warp_role == ProducerWarpRole::Warp1) { - if constexpr (IsSchedDynamicPersistent) { - bool requires_clc_query = true; - TileSchedulerPipelineState scheduler_pipe_producer_state = cutlass::make_producer_start_state(); - - while (work_tile_info.is_valid()) { - - if (requires_clc_query) { - - // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. - scheduler_throttle_pipeline.consumer_wait(scheduler_pipe_throttle_consumer_state); - scheduler_throttle_pipeline.consumer_release(scheduler_pipe_throttle_consumer_state); - ++scheduler_pipe_throttle_consumer_state; - - // Query next work tile - scheduler_pipe_producer_state = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_producer_state); - } - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - - work_tile_info = next_work_tile_info; - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - - // Terminal condition - if work_tile_info is end-of-grid, produce an extra invalid tile - scheduler_pipeline.producer_acquire(scheduler_pipe_producer_state); - scheduler.store_invalid_response(scheduler_pipe_producer_state); // Push invalid tile to smem - scheduler_pipeline.producer_commit(scheduler_pipe_producer_state); // Manual completion of transaction - ++scheduler_pipe_producer_state; - - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - - scheduler_pipeline.producer_tail(scheduler_pipe_producer_state); - } - } // Scheduler Producer Warp End - else - - // Mainloop Producer Warp - if (producer_warp_role == ProducerWarpRole::Mainloop) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - bool do_load_order_arrive = true; - bool requires_clc_query = true; - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - - if (requires_clc_query) { - scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state); - scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state); - ++scheduler_pipe_throttle_producer_state; - } - - collective_mainloop.load( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - if constexpr (IsSchedDynamicPersistent) { - // Get next work tile - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - - work_tile_info = next_work_tile_info; - requires_clc_query = increment_pipe; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - else { - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if constexpr (IsSchedDynamicPersistent) { - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - } - - } // Mainloop Producer Warp End - - else if (producer_warp_role == ProducerWarpRole::MainloopAux) { - if constexpr (IsMainloopAuxiliaryLoadNeeded) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - collective_mainloop.load_auxiliary( - params.mainloop, - mainloop_pipeline, - mainloop_pipe_producer_state, - load_inputs, - blk_coord, - k_tile_iter, k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if constexpr (IsSchedDynamicPersistent) { - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, - scheduler_pipeline, - scheduler_pipe_consumer_state - ); - } - - } - } - - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { - - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); - - bool do_load_order_wait = true; - while (work_tile_info.is_valid()) { - if (do_load_order_wait) { - load_order_barrier.wait(); - do_load_order_wait = false; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state = - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue - ); - - if constexpr (IsSchedDynamicPersistent) { - // Get next work tile - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - } - } - else { - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - - if constexpr (IsSchedDynamicPersistent) { - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - } - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - cutlass::arch::warpgroup_reg_alloc(); - - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - // It is possible to have work tiles start off invalid, - // so we have to check that first. - if (not work_tile_info.is_valid()) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); - - return; - } - #endif - - if constexpr (IsSchedDynamicPersistent) { - // Consumer0's initial tile is static. It starts consuming the 2nd tile. - if (warp_group_role == WarpGroupRole::Consumer0) { - ++scheduler_pipe_consumer_state; - } - - if (warp_group_role == WarpGroupRole::Consumer1) { - // Get next work tile - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - ++scheduler_pipe_consumer_state; - } - } - } - - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Allocate the accumulators for the (M,N) blk_shape - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - // Order two Math WG's MMA one after the other, helps hide Epilogue - math_wg_order_barrier.wait(); - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - k_tile_count, - warp_group_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Cue for next Math WG's MMA to start - math_wg_order_barrier.arrive(); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - k_tile_count - ); - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); - - } - #endif - - // Order two Math WG's Epilogue one after the other - math_wg_order_barrier.wait(); - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - - // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels - // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives - // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. - auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state_next, - epi_store_pipeline, - epi_store_pipe_producer_state_next - ); - - // Update starting load/store pipeline states for the next tile - // state has already been incremented by 1 tile in collective calls, advance once again for ping pong - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - - // Cue for next Math WG's Epilogue to start - math_wg_order_barrier.arrive(); - - if constexpr (IsSchedDynamicPersistent) { - // Get next work tile - auto [next_work_tile_info, increment_pipe] = - scheduler.fetch_next_work( - work_tile_info, scheduler_pipeline, scheduler_pipe_consumer_state); - - work_tile_info = next_work_tile_info; - if (increment_pipe) { - ++scheduler_pipe_consumer_state; - ++scheduler_pipe_consumer_state; - } - } - else { - // Get next work tile - scheduler.advance_to_next_work(NumMmaWarpGroups); - work_tile_info = scheduler.get_current_work(); - } - } // Scheduler work fetch loop - } // Consumer Warp Groups End -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp deleted file mode 100644 index e7cafde5338941287ae2628cdc7bcb36b9644c31..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ /dev/null @@ -1,417 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/tensor.hpp" -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - static constexpr bool IsGdcEnabled = false; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(cute::is_void_v or cute::is_same_v, - "Non-persistent warp-specialized kernel does not support specializing the tile scheduler."); - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - - // Kernel level shared memory storage - struct SharedStorage { - union TensorStorage { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; - using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; - static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); - - static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; - static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); - - static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - return implementable; - } - - static - size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static - cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - auto cluster_shape = Shape<_1,_1,_1>{}; - auto tile_shape = TileShape{}; - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - return TileScheduler::get_tiled_cta_shape_mnl( - problem_shape_MNKL, tile_shape, cluster_shape); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - enum class WarpGroupRole { - Producer = 0, - Consumer = 1, - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int warp_group_idx = canonical_warp_group_idx(); - CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); - WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; - mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; - epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // Represent the full tensors - Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) - Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) - - // Get the appropriate blocks for this thread block -- potential for thread block locality - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - TiledMma tiled_mma; - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - // Compute m_coord, n_coord, and l_coord with their post-tiled shapes - auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); - auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Slice with m_coord and n_coord - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - // Get pipeline iterators and increments from tensor shapes - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - auto k_tile_count = size<2>(gA); - auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - // Wait for all threads in the thread block - __syncthreads(); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; - - if (warp_group_role == WarpGroupRole::Producer) { - // Compute tile residues for predication - auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord - auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord - auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); - - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, - gB, - k_tile_iter, k_tile_count, - residue_mnk, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting mainloop pipeline state for the pipeline drain - mainloop_pipe_producer_state.advance(k_tile_count); - // Make sure mainloop consumer has been waited upon before issuing epilogue load - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_producer_load_needed()) { - epi_load_pipe_producer_state = - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - thread_idx, - shared_storage.tensors.epilogue - ); - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } - } - else if (warp_group_role == WarpGroupRole::Consumer) { - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - k_tile_count, - warp_group_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - k_tile_count - ); - - // Epilogue and write to gD - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - } -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp deleted file mode 100644 index 1d35ff2dc8c3992e7942a0be5da929febd771cae..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ /dev/null @@ -1,515 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/tensor.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - static constexpr bool IsGdcEnabled = false; - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; - using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; - static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same"); - - static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; - static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); - - static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Kernel level shared memory storage - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - // Get maximum number of clusters that could co-exist on the target device - int max_active_clusters = args.hw_info.max_active_clusters; - if (max_active_clusters <= 0) { - max_active_clusters = 0; - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); - } - else { - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); - } - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); - - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), - hw_info, - scheduler - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - return implementable; - } - - static - size_t - get_workspace_size(Arguments const& args) { - TileScheduler t; - return t.template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - } - - static - cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - TileScheduler t; - static constexpr uint32_t NumEpilogueSubTiles = 1; - static constexpr uint32_t NumAccumulatorMtxs = 1; - return t.template initialize_workspace( - args.scheduler, workspace, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - /* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */ - enum class WarpGroupRole { - Producer = 0, - Consumer = 1, - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int mma_thread_idx = thread_idx % size(TiledMma{}); - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int warp_group_idx = canonical_warp_group_idx(); - CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); - WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; - mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; - epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // Represent the full tensors - Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) - Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; - - // Wait for all threads in the thread block - __syncthreads(); - - if (warp_group_role == WarpGroupRole::Producer) { - - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<2>(gA)), shape<2>(gA)); - - // Compute tile residues for predication - auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord - auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord - auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); - - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, - gB, - k_tile_iter, work_k_tile_count, - residue_mnk, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler) && - collective_epilogue.is_producer_load_needed()) { - epi_load_pipe_producer_state = - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); - work_tile_info = next_work_tile_info; - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_producer_load_needed()) { - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer) { - - bool do_store_tail = false; - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - - // Allocate the the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - work_k_tile_count, - mma_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - work_k_tile_count - ); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); - - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - mma_thread_idx, - shared_storage.tensors.epilogue - ); - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); - work_tile_info = next_work_tile_info; - } // Scheduler work fetch loop - - if (do_store_tail) { - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state - ); - } - } // Consumer Warp Groups End -#endif - } - -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp deleted file mode 100644 index be086f0c9c5dcd21d68dadc0d67ac1c3844373f8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ /dev/null @@ -1,527 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/fast_math.h" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/gemm/kernel/gemm_universal_decl.h" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" - -#include "cute/tensor.hpp" -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel { - -/////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> -class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - static constexpr bool IsGdcEnabled = false; - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(!cute::is_same_v, "Ping-pong kernel does not currently support stream-K scheduler."); - using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; - using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; - static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same"); - - static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumMmaWarpGroups = 2 * cute::size(TiledMma{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; - static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); - static_assert(NumMmaWarpGroups == 2, "Pingpong kernel requires 2 MMA warp groups."); - - static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue - static constexpr uint32_t StagesPerMathWarpGroup = 2; - using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< - StagesPerMathWarpGroup, NumMmaWarpGroups>; - - // Kernel level shared memory storage - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128, _1> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16, _1> { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - (void) workspace; - auto problem_shape = args.problem_shape; - if constexpr (detail::Has_SwapAB_v) { - // swap M/N - get<0>(problem_shape) = get<1>(args.problem_shape); - get<1>(problem_shape) = get<0>(args.problem_shape); - } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - // Get maximum number of clusters that could co-exist on the target device - int max_active_clusters = args.hw_info.max_active_clusters; - if (max_active_clusters <= 0) { - max_active_clusters = 0; - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); - } - else { - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); - } - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); - - return { - args.mode, - problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), - hw_info, - scheduler - }; - } - - static bool - can_implement(Arguments const& args) { - bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - - return implementable; - } - - static - size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static - cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - enum class WarpGroupRole { - Producer = 0, - Consumer = 1, - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int warp_group_idx = canonical_warp_group_idx(); - CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); - WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; - int warp_group_consumer_idx = warp_group_idx - NumLoadWarpGroups; - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; - mainloop_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; // only 1 WG consumes at a time - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; // only 1 WG consumes at a time - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; - // DMA Load WG will not participate in these Ordered Barrier syncs - params_math_wg_order_barrier.group_id = warp_group_consumer_idx; - params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group - MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // Represent the full tensors - Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) - Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - TileScheduler scheduler{params.scheduler}; - - if (warp_group_consumer_idx == 1) { - // Advance 2nd Math WG to the next work tile for the startup - scheduler.advance_to_next_work(); - // Advance 2nd Math WG pipeline states to the end of 1st Math WG - mainloop_pipe_consumer_state.advance(k_tile_count); - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - } - auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; - - // Wait for all threads in the thread block - __syncthreads(); - - if (warp_group_role == WarpGroupRole::Producer) { - - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - - // Compute tile residues for predication - auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord - auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord - auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); - - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, - gB, - k_tile_iter, k_tile_count, - residue_mnk, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - if (collective_epilogue.is_producer_load_needed()) { - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - // Update starting pipeline state for the next tile - epi_load_pipe_producer_state.advance(c_tile_count); - } - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - if (collective_epilogue.is_producer_load_needed()) { - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer) { - - while (work_tile_info.is_valid()) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Allocate the the accumulators for the (M,N) blk_shape - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - // Order two Math WG's MMA one after the other, helps hide Epilogue - math_wg_order_barrier.wait(); - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - k_tile_count, - thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Cue for next Math WG's MMA to start - math_wg_order_barrier.arrive(); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - k_tile_count - ); - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - - // Order two Math WG's Epilogue one after the other - math_wg_order_barrier.wait(); - - // Epilogue and write to gD - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - // Update starting load/store pipeline states for the next tile - epi_load_pipe_consumer_state.advance(c_tile_count * NumMmaWarpGroups); - epi_store_pipe_producer_state.advance(d_tile_count * NumMmaWarpGroups); - - // Wait for all TMA stores to complete - epi_store_pipeline.producer_tail(epi_store_pipe_producer_state); - - // Cue for next Math WG's Epilogue to start - math_wg_order_barrier.arrive(); - - // Get next work tile - scheduler.advance_to_next_work(NumMmaWarpGroups); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - } // Consumer Warp Groups End -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp deleted file mode 100644 index dd90d48f1bd82e9d14cdc41dd93f402d8bd20363..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ /dev/null @@ -1,153 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/gemm/kernel/static_tile_scheduler.hpp" - -namespace cutlass::gemm::kernel::detail { - -/////////////////////////////////////////////////////////////////////////////// - -// Persistent Thread Block (TB) scheduler -class PersistentTileSchedulerSm90: -public StaticPersistentTileScheduler { - - using BaseScheduler = StaticPersistentTileScheduler; -public: - using StaticPersistentTileScheduler::StaticPersistentTileScheduler; - using Params = PersistentTileSchedulerSm90Params; - using RasterOrder = typename Params::RasterOrder; - using RasterOrderOptions = typename Params::RasterOrderOptions; - using Arguments = BaseScheduler::Arguments; - - static constexpr bool IsDynamicPersistent = false; - - using Pipeline = PipelineEmpty; - using PipelineStorage = typename Pipeline::SharedStorage; - using ThrottlePipeline = PipelineEmpty; - using ThrottlePipelineStorage = typename ThrottlePipeline::SharedStorage; - - struct CLCResponse {}; - - class SharedStorage { - public: - CUTLASS_DEVICE PipelineStorage pipeline() { return PipelineStorage{}; } - CUTLASS_DEVICE ThrottlePipelineStorage throttle_pipeline() { return ThrottlePipelineStorage{}; } - CUTLASS_DEVICE CLCResponse* data() { return nullptr; } - }; - - // get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle - static CUTLASS_DEVICE - cute::tuple - get_work_idx_m_and_n( - uint64_t blk_per_grid_dim, - FastDivmodU64Pow2 const& divmod_cluster_shape_major, - FastDivmodU64Pow2 const& divmod_cluster_shape_minor, - FastDivmodU64 const& divmod_cluster_blk_major, - int32_t log_swizzle_size, - RasterOrder raster_order) { - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); - return get_work_idx_m_and_n( - blk_per_grid_dim, - divmod_cluster_shape_major, - divmod_cluster_shape_minor, - divmod_cluster_blk_major, - log_swizzle_size, - raster_order, - cta_m_in_cluster, - cta_n_in_cluster - ); - } - - static CUTLASS_DEVICE - cute::tuple - get_work_idx_m_and_n( - uint64_t blk_per_grid_dim, - FastDivmodU64Pow2 const& divmod_cluster_shape_major, - FastDivmodU64Pow2 const& divmod_cluster_shape_minor, - FastDivmodU64 const& divmod_cluster_blk_major, - int32_t log_swizzle_size, - RasterOrder raster_order, - uint64_t cta_m_in_cluster, - uint64_t cta_n_in_cluster) { - - uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; - divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); - - if (raster_order == RasterOrder::AlongN) { - cluster_minor_offset = cta_m_in_cluster; - } - else { - cluster_minor_offset = cta_n_in_cluster; - } - - uint64_t cluster_idx_minor, cluster_idx_major; - - uint64_t cluster_idx_minor_div_swizzle, extra, offset; - - offset = cluster_id & ((1 << log_swizzle_size) - 1); - extra = cluster_id >> log_swizzle_size; - - divmod_cluster_blk_major(cluster_idx_minor_div_swizzle, cluster_idx_major, extra); - - cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; - - auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + - cluster_minor_offset); - auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + - cluster_major_offset); - - if (raster_order == RasterOrder::AlongN) { - return {minor_work_idx, major_work_idx}; - } - else { - return {major_work_idx, minor_work_idx}; - } - - } - - // The basic tile scheduler does not require any additional workspace - template - static size_t - get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, - uint32_t, const uint32_t = 1, uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - -}; - -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp deleted file mode 100644 index 92749b196640e5682a0aa09e5c9c4d8c8c08f2f6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ /dev/null @@ -1,586 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/fast_math.h" -#include "cutlass/gemm_coord.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/gemm/kernel/tile_scheduler_params.h" -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/arch/cluster_sm90.hpp" - -namespace cutlass::gemm::kernel::detail { - -/////////////////////////////////////////////////////////////////////////////// - -// Persistent Thread Block (TB) scheduler -template -class PersistentTileSchedulerSm90Group { - // - // Data members - // - -private: - uint64_t current_work_linear_idx_ = 0; - uint64_t total_grid_size_ = 0; - - // Tracking current group, its starting linear idx and total tiles - struct GroupInfo { - int group_idx = 0; - uint64_t start_linear_idx = 0; - uint64_t total_tiles = 0; - uint64_t problem_blocks_along_raster_order = 0; - } current_group_info_; - -public: - struct WorkTileInfo { - int32_t M_idx = 0; - int32_t N_idx = 0; - int32_t L_idx = 0; - int32_t is_valid_tile = 0; - - CUTLASS_HOST_DEVICE - bool - is_valid() const { - return is_valid_tile != 0; - } - - CUTLASS_HOST_DEVICE - static WorkTileInfo - invalid_work_tile() { - return {-1, -1, -1, 0}; - } - - CUTLASS_HOST_DEVICE - bool - is_final_split(uint32_t k_tiles_per_output_tile) const { - return true; - } - - CUTLASS_HOST_DEVICE - int32_t - reduction_subtile_idx() const { - return -1; - } - }; - - using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; - using Params = PersistentTileSchedulerSm90GroupParams; - using RasterOrder = typename Params::RasterOrder; - using RasterOrderOptions = typename Params::RasterOrderOptions; - static constexpr bool IsDynamicPersistent = false; - - // We need to hard code the number of stages here since the scheduling is static - // and it can benefit from a larger number of stages without worrying about imbalances. - - using Pipeline = PipelineAsync; - - // Call out the types here to work around a bug in MSVC. - - // using PipelineStorage = typename Pipeline::SharedStorage; - // using PipelineState = typename Pipeline::PipelineState; - using PipelineStorage = cutlass::PipelineDetail::PipelineAsyncSharedStorage; - using PipelineState = cutlass::PipelineDetail::PipelineAsyncPipelineState; - - using ThrottlePipeline = PipelineEmpty; - using ThrottlePipelineStorage = typename PipelineEmpty::SharedStorage; - using SchedulerResponse = WorkTileInfo; - - class SharedStorage { - public: - CUTLASS_DEVICE PipelineStorage pipeline() { return pipeline_; } - // Pipeline throttle is not needed here as the scheduling is not dynamic. - CUTLASS_DEVICE ThrottlePipelineStorage throttle_pipeline() { return ThrottlePipelineStorage{}; } - CUTLASS_DEVICE SchedulerResponse* data() { return data_; } - - private: - alignas(16) PipelineStorage pipeline_; - alignas(16) SchedulerResponse data_[SchedulerPipelineStageCount]; - }; - - struct Arguments { - int max_swizzle_size = 1; - // Not applying Heuristics for Grouped problems, since largest dimension can change per group - RasterOrderOptions raster_order = RasterOrderOptions::AlongM; - }; - - // Sink scheduler params as a member - Params scheduler_params; - SchedulerResponse *response_ptr_ = nullptr; - ProblemShape cached_problem_shapes_[2]; - - // - // Methods - // - - template - static Params - to_underlying_arguments( - GroupProblemShape problem_shapes, - TileShape tile_shape, - ClusterShape cluster_shape, - KernelHardwareInfo const& hw_info, - Arguments const& arguments, - [[maybe_unused]] void* workspace=nullptr, - [[maybe_unused]] const uint32_t epilogue_subtile = 1, - [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u - ) { - - // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic - static_assert(cute::is_static::value); - static_assert(cute::is_static::value); - - dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes, - hw_info, - tile_shape, cluster_shape); - - Params params; - params.initialize( - problem_blocks, - problem_shapes, - to_gemm_coord(tile_shape), - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order - ); - - return params; - } - - // Given the inputs, computes the physical grid we should launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - [[maybe_unused]] Params const& params, - GroupProblemShape const& problem_shapes, - TileShape tile_shape, - ClusterShape cluster_shape, - KernelHardwareInfo hw_info, - Arguments arguments, - bool truncate_by_problem_size=true) { - - dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes, - hw_info, - tile_shape, cluster_shape); - - return Params::get_grid_shape( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order, - /* truncate_by_problem_size = */true - ); - } - - // Given the inputs, computes the total number of output blocks this problem will compute over - // Note that this is only the logical size of our grid, not the physical grid we will actually launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_tiled_cta_shape_mnl(GroupProblemShape const& problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { - int groups = problem_shapes.groups(); - uint32_t total_ctas = 0; - uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here - - // If host problem shapes are not provided. - if (!problem_shapes.is_host_problem_shape_available()) { - total_ctas = hw_info.sm_count; - } - // If host problem shapes are provided, make a better decision about possibility to launch smaller grid. - else { - for (int group = 0; group < groups; group++) { - auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape))); - auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape))); - auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape)); - auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape)); - total_ctas += problem_blocks_m * problem_blocks_n; - } - } - - return Params::get_tiled_cta_shape_mnl( - to_gemm_coord(cluster_shape), - total_ctas, cta_in_N_dim - ); - } - - static bool - can_implement(Arguments const& args) { - return true; - } - - PersistentTileSchedulerSm90Group() = default; - - CUTLASS_DEVICE explicit PersistentTileSchedulerSm90Group(Params const& params_, SchedulerResponse* response_ptr) : scheduler_params(params_), response_ptr_(response_ptr) { - // MSVC requires protecting use of CUDA-specific nonstandard syntax, - // like blockIdx and gridDim, with __CUDA_ARCH__. -#if defined(__CUDA_ARCH__) - if (scheduler_params.raster_order_ == RasterOrder::AlongN) { - current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); - } - else { - current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); - } - - int lane_idx = canonical_lane_idx(); - if (lane_idx < params_.problem_shapes_.groups()) { - cached_problem_shapes_[1] = params_.problem_shapes_.get_problem_shape(lane_idx); - } - - total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); - uint64_t ctas_along_m, ctas_along_n; - ProblemShape problem_shape = params_.problem_shapes_.get_problem_shape(0); - if (is_tuple(problem_shape))>::value || - is_tuple(problem_shape))>::value) { - ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape), scheduler_params.cta_shape_.m())); - ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape), scheduler_params.cta_shape_.n())); - } - else { - ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(problem_shape) + scheduler_params.divmod_cta_shape_m_.divisor - 1); - ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(problem_shape) + scheduler_params.divmod_cta_shape_n_.divisor - 1); - } - auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m()); - auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n()); - current_group_info_.total_tiles = problem_blocks_m * problem_blocks_n; - current_group_info_.problem_blocks_along_raster_order = params_.raster_order_ == RasterOrder::AlongN ? problem_blocks_n : problem_blocks_m; - -#else - CUTLASS_ASSERT(false && "This line should never be reached"); -#endif - } - - // get work_idx_m, work_idx_n from linear_idx while applying swizzle - template - static - CUTLASS_DEVICE - WorkTileInfo - get_work_idx_m_and_n( - uint64_t linear_idx, - GroupInfo& group_info, - GroupProblemShape &problem_shapes, - ProblemShape (&cached_problem_shapes)[2], - GemmCoord cta_shape, - GemmCoord cluster_shape, - FastDivmodU64Pow2 const& divmod_cluster_shape_major, - FastDivmodU64Pow2 const& divmod_cluster_shape_minor, - FastDivmodU64 const& divmod_cta_shape_m, - FastDivmodU64 const& divmod_cta_shape_n, - int32_t log_swizzle_size, - RasterOrder raster_order) { - - int32_t valid_tile = 1; - - // Use a warp to "speculatively" check if the work tile maps to the next 32 groups - int lane_idx = canonical_lane_idx(); - int total_problem_groups = problem_shapes.groups(); - - if (linear_idx >= group_info.total_tiles + group_info.start_linear_idx) { - group_info.group_idx += lane_idx; - for ( ; ; group_info.group_idx += NumThreadsPerWarp) { - cached_problem_shapes[0] = cached_problem_shapes[1]; - if (group_info.group_idx + NumThreadsPerWarp < total_problem_groups) { - cached_problem_shapes[1] = problem_shapes.get_problem_shape(group_info.group_idx + NumThreadsPerWarp); - } - if (group_info.group_idx < total_problem_groups) { - uint64_t ctas_along_m, ctas_along_n; - if (is_tuple(cached_problem_shapes[0]))>::value || - is_tuple(cached_problem_shapes[0]))>::value) { - ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(cached_problem_shapes[0]), cta_shape.m())); - ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(cached_problem_shapes[0]), cta_shape.n())); - } - else { - ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(cached_problem_shapes[0]) + divmod_cta_shape_m.divisor - 1); - ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(cached_problem_shapes[0]) + divmod_cta_shape_n.divisor - 1); - } - auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); - group_info.problem_blocks_along_raster_order = raster_order == RasterOrder::AlongN ? problem_blocks_n : problem_blocks_m; - group_info.total_tiles = problem_blocks_m * problem_blocks_n; - } else { - group_info.total_tiles = INT_MAX; - } - - auto curr_total_tiles = group_info.total_tiles; - - // Calculate prefix sum for start_linear_idx. - #pragma unroll - for (int i = 1; i < NumThreadsPerWarp; i *= 2) { - auto n = __shfl_up_sync(0xffffffff, curr_total_tiles, i); - curr_total_tiles = lane_idx >= i ? curr_total_tiles + n : curr_total_tiles; - } - group_info.start_linear_idx += curr_total_tiles - group_info.total_tiles; - - uint32_t thread_succeed = __ballot_sync(0xffffffff, linear_idx < group_info.start_linear_idx + group_info.total_tiles); - if (thread_succeed) { - // Use the first succeeding thread. - int first_succeeding_thread = __ffs(thread_succeed) - 1; - group_info.group_idx = __shfl_sync(0xffffffff, group_info.group_idx, first_succeeding_thread); - group_info.start_linear_idx = __shfl_sync(0xffffffff, group_info.start_linear_idx, first_succeeding_thread); - group_info.total_tiles = __shfl_sync(0xffffffff, group_info.total_tiles, first_succeeding_thread); - group_info.problem_blocks_along_raster_order = __shfl_sync(0xffffffff, group_info.problem_blocks_along_raster_order, first_succeeding_thread); - if (group_info.group_idx + lane_idx < total_problem_groups) { - cached_problem_shapes[1] = problem_shapes.get_problem_shape(group_info.group_idx + lane_idx); - } - break; - } - // Update the start_linear_idx for all threads so that they're ready for the next iteration. - group_info.start_linear_idx = __shfl_sync(0xffffffff, group_info.start_linear_idx + group_info.total_tiles, NumThreadsPerWarp - 1); - } - } - - if (group_info.group_idx >= total_problem_groups) { - return WorkTileInfo::invalid_work_tile(); - } - - uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; - uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx); - divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); - - // With static schedulers, we launch grid such that all cluster are linear (1-D) order, i.e., - // there can only be one cluster in the minor dimension. get_grid_shape() in scheduler params - // put cluster_shape.m/n() as the minor dimension based on raster order AlongN/M resp. - // Therefore, the offset of a CTA (inside a cluster) in the minor dimension can be directly be - // inferred by the blockIdx along the minor dimension. - if (raster_order == RasterOrder::AlongN) { - cluster_minor_offset = blockIdx.x; - } - else { - cluster_minor_offset = blockIdx.y; - } - - uint64_t cluster_idx_minor, cluster_idx_major; - - uint64_t cluster_idx_minor_div_swizzle, extra, offset; - - offset = cluster_id & ((1 << log_swizzle_size) - 1); - extra = cluster_id >> log_swizzle_size; - - uint64_t curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(group_info.problem_blocks_along_raster_order); - - cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; - cluster_idx_major = extra % curr_group_cluster_blk_major; - - cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; - - auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + - cluster_minor_offset); - auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + - cluster_major_offset); - - if (raster_order == RasterOrder::AlongN) { - return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile}; - } - else { - return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile}; - } - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work_for_linear_idx(uint64_t linear_idx) { - if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) { - return WorkTileInfo::invalid_work_tile(); - } - return get_work_idx_m_and_n( - linear_idx, - current_group_info_, - scheduler_params.problem_shapes_, - cached_problem_shapes_, - scheduler_params.cta_shape_, - scheduler_params.cluster_shape_, - scheduler_params.divmod_cluster_shape_major_, - scheduler_params.divmod_cluster_shape_minor_, - scheduler_params.divmod_cta_shape_m_, - scheduler_params.divmod_cta_shape_n_, - scheduler_params.log_swizzle_size_, - scheduler_params.raster_order_); - } - template - CUTLASS_DEVICE - auto - advance_to_next_work( - TileSchedulerPipeline& scheduler_pipeline, - TileSchedulerPipelineState scheduler_pipe_producer_state, - uint32_t advance_count = 1) { - - current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); - auto work_tile = get_current_work_for_linear_idx(current_work_linear_idx_); - scheduler_pipeline.producer_acquire(scheduler_pipe_producer_state); - if (cute::elect_one_sync()) { - response_ptr_[scheduler_pipe_producer_state.index()] = work_tile; - cutlass::arch::fence_view_async_shared(); - scheduler_pipeline.producer_commit(scheduler_pipe_producer_state); - } - return cute::make_tuple(work_tile, true); - } - - // Returns whether the block assigned this work should compute the epilogue for the corresponding - // output tile. For the basic tile scheduler, this is always true. - CUTLASS_HOST_DEVICE - static bool - compute_epilogue(WorkTileInfo const&, Params const&) { - return true; - } - - // Performs the reduction across splits for a given output tile. Since this scheduler does - // not split output tiles, no reduction is needed. - template - CUTLASS_DEVICE - static void - fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} - - // Returns whether the current WorkTileInfo passed in should continue to be used. Since - // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo - // passed in should not be used after having been processed. - CUTLASS_DEVICE - static bool - continue_current_work(WorkTileInfo&) { - return false; - } - - // The basic tile scheduler does not require any additional workspace - template - static size_t - get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, - uint32_t, const uint32_t = 1, uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - template - CUTLASS_HOST_DEVICE - static int - get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape_MNKL problem_shape, TileShape tile_shape) { - // All work units returned by this scheduler cover the entire K iteration - // space of the output tile assigned to the work unit. - return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); - } - - CUTLASS_HOST_DEVICE - static uint32_t - get_work_k_tile_start(WorkTileInfo const&) { - // All work units returned by this scheduler start from K tile 0 - return 0u; - } - - CUTLASS_DEVICE - static bool - need_separate_reduction(Params const& params) { - return false; - } - - CUTLASS_DEVICE - bool - is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { - return false; - } - - CUTLASS_DEVICE - uint32_t - epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const { - return 0; - } - - template - CUTLASS_DEVICE - void - separate_reduction( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - } - - // Shares the accumulator set with peers in the global workspace - template - CUTLASS_DEVICE - static void - share( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - } - - CUTLASS_DEVICE - static bool - valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { - return true; - } - - CUTLASS_DEVICE - static bool - requires_separate_reduction(Params const& params) { - return false; - } - - // Kernel helper function to get next work tile - template - CUTLASS_DEVICE - auto - fetch_next_work( - WorkTileInfo work_tile_info, - TileSchedulerPipeline& scheduler_pipeline, - TileSchedulerPipelineState scheduler_pipe_consumer_state) { - - if (continue_current_work(work_tile_info)) { - return cute::make_tuple(work_tile_info, true); - } - scheduler_pipeline.consumer_wait(scheduler_pipe_consumer_state); - auto work_tile = response_ptr_[scheduler_pipe_consumer_state.index()]; - cutlass::arch::fence_view_async_shared(); - scheduler_pipeline.consumer_release(scheduler_pipe_consumer_state); - - return cute::make_tuple(work_tile, true); - } - - // Returns the initial work tile info that will be computed over - template - CUTLASS_DEVICE - auto - initial_work_tile_info(ClusterShape) { - return get_current_work_for_linear_idx(current_work_linear_idx_); - } -}; - -} // namespace cutlass::gemm::kernel::detail diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp deleted file mode 100644 index a298e06bf4e65b068d1cb1935d9325551b428c68..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ /dev/null @@ -1,1113 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/barrier.h" -#include "cutlass/block_striped.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cute/layout.hpp" -#include "cute/tensor.hpp" - -namespace cutlass::gemm::kernel::detail { - -// Persistent Thread Block (TB) scheduler leveraging stream-K decomposition -template < - class TileShape, - class ClusterShape -> -class PersistentTileSchedulerSm90StreamK { - // - // Data members - // - -private: - using UnderlyingScheduler = PersistentTileSchedulerSm90; - -private: - using UnderlyingArguments = typename UnderlyingScheduler::Arguments; - using UnderlyingParams = typename UnderlyingScheduler::Params; - - dim3 block_id_in_cluster_; - uint64_t current_work_linear_idx_ = 0; - uint32_t unit_iter_start_ = 0; - -public: - - using RasterOrder = UnderlyingScheduler::RasterOrder; - using RasterOrderOptions = UnderlyingScheduler::RasterOrderOptions; - static constexpr bool IsDynamicPersistent = false; - - using Pipeline = PipelineEmpty; - using PipelineStorage = typename Pipeline::SharedStorage; - using ThrottlePipeline = PipelineEmpty; - using ThrottlePipelineStorage = typename ThrottlePipeline::SharedStorage; - struct CLCResponse {}; - - class SharedStorage { - public: - CUTLASS_DEVICE PipelineStorage pipeline() { return PipelineStorage{}; } - CUTLASS_DEVICE ThrottlePipelineStorage throttle_pipeline() { return ThrottlePipelineStorage{}; } - CUTLASS_DEVICE CLCResponse* data() { return nullptr; } - }; - - // Use a dummy barrier manager to simply get the type used to store the barrier - using BarrierType = typename NamedBarrierManager<1>::T; - - using Params = PersistentTileSchedulerSm90StreamKParams; - using ReductionMode = Params::ReductionMode; - using DecompositionMode = Params::DecompositionMode; - - struct WorkTileInfo { - int32_t M_idx = 0; - int32_t N_idx = 0; - int32_t K_idx = 0; - int32_t L_idx = 0; - - // Number of k tiles to compute for this unit of work. For stream-K, this - // can indicate the number of K tiles across multiple output tiles. - uint32_t k_tile_count = 0; - - // Number of k tiles remaining for the work unit as a whole - uint32_t k_tile_remaining = 0; - - // Whether this unit of work is the final split for the given tile - bool is_separate_reduction = false; - - CUTLASS_HOST_DEVICE - bool - is_valid() const { - // A work tile that computes no K tiles is invalid unless it is a separate-reduction work tile - // (which only performs reduction and epilogue) - return k_tile_count > 0 || is_separate_reduction; - } - - CUTLASS_HOST_DEVICE - bool - is_reduction_unit() const { - return is_separate_reduction; - } - - CUTLASS_HOST_DEVICE - int32_t - reduction_subtile_idx() const { - // For separate reduction units, the K_idx of the work tile is unused. - // Therefore, we override it to contain the subtile of that the reduction - // unit operates on. - return is_reduction_unit() ? K_idx : -1; - } - - CUTLASS_HOST_DEVICE - void - setup_separate_reduction(int32_t epilogue_subtile_idx) { - // Set the epilogue subtile in the K_idx, since this is otherwise unused - // by separate reduction units. - K_idx = epilogue_subtile_idx; - - is_separate_reduction = true; - k_tile_count = 0; - // Clean up remaining k tiles - k_tile_remaining = 0; - } - - CUTLASS_HOST_DEVICE - static WorkTileInfo - invalid_work_tile() { - return {-1, -1, -1, -1, 0}; - } - - CUTLASS_HOST_DEVICE - bool - is_final_split(uint32_t k_tiles_per_output_tile) const { - return (K_idx + k_tile_count) == k_tiles_per_output_tile; - } - }; - - struct Arguments { - - Arguments() = default; - Arguments(Arguments const&) = default; - Arguments(Arguments&&) = default; - - CUTLASS_HOST_DEVICE - Arguments& - operator=(Arguments const& args) { - splits = args.splits; - max_swizzle_size = args.max_swizzle_size; - raster_order = args.raster_order; - reduction_mode = args.reduction_mode; - decomposition_mode = args.decomposition_mode; - return *this; - } - - CUTLASS_HOST_DEVICE - Arguments& - operator=(Arguments&& args) noexcept { - splits = args.splits; - max_swizzle_size = args.max_swizzle_size; - raster_order = args.raster_order; - reduction_mode = args.reduction_mode; - decomposition_mode = args.decomposition_mode; - return *this; - } - - CUTLASS_HOST_DEVICE - Arguments(int splits_) : splits(splits_) {} - - CUTLASS_HOST_DEVICE - Arguments(int splits_, int max_swizzle_size_, RasterOrderOptions raster_order_, DecompositionMode decomposition_mode_) : - splits(splits_), - max_swizzle_size(max_swizzle_size_), - raster_order(raster_order_), - decomposition_mode(decomposition_mode_) {} - - // The splitting factor to be used in a split-K decomposition of the problem. - // If this is set to a value greater than 1, stream-K decomposition logic - // is bypassed in favor of a split-K decomposition. - int splits = 1; - int max_swizzle_size = 1; - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; - ReductionMode reduction_mode = ReductionMode::Deterministic; - DecompositionMode decomposition_mode = DecompositionMode::Heuristic; - }; - - // Sink scheduler params as a member - Params scheduler_params; - - // - // Methods - // - - template - static Params - to_underlying_arguments( - ProblemShape problem_shape, - TileShape tile_shape, - ClusterShape cluster_shape, - KernelHardwareInfo const& hw_info, - Arguments const& args, - void* workspace, - const uint32_t epilogue_subtile = 1, - [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) { - - static_assert(cute::is_static::value); - static_assert(cute::is_static::value); - - auto problem_shape_mnkl = cute::append<4>(problem_shape, cute::Int<1>{}); - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); - - Params params; - params.initialize( - problem_blocks, - k_tile_per_output_tile, - to_gemm_coord(cluster_shape), - hw_info, - args.splits, - args.max_swizzle_size, - args.raster_order, - args.reduction_mode, - args.decomposition_mode, - workspace, - epilogue_subtile - ); - return params; - } - - static bool - can_implement(Arguments const& args) { - // Split count > 1 is only valid for heuristic and split-K decomposition modes - return (args.splits == 1 || - args.decomposition_mode == DecompositionMode::Heuristic || - args.decomposition_mode == DecompositionMode::SplitK); - } - - CUTLASS_HOST_DEVICE - PersistentTileSchedulerSm90StreamK() { }; - - CUTLASS_DEVICE - PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_), block_id_in_cluster_(cute::block_id_in_cluster()) { - if (params_.raster_order_ == RasterOrder::AlongN) { - current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); - } - else { - current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); - } - - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work() { - return get_current_work_for_linear_idx(unit_iter_start_, current_work_linear_idx_, block_id_in_cluster_, scheduler_params); - } - - CUTLASS_DEVICE - static WorkTileInfo - get_current_work_for_linear_idx(uint32_t &unit_iter_start, uint64_t linear_idx, dim3 block_id_in_cluster, Params const& params) { - // The maximum number of work units is units_per_problem_ * splits_. - // The multiplication by splits_ is used for handling split-K, in which - // units_per_problem_ is equal to the total number of output tiles. To account - // for the fact that we have splits_ peers per output tile, we multiply this - // value by splits_. For stream-K, this multiplication ends up being a no-op - // because splits_ is set to 1 for stream-K. - if(linear_idx >= (params.units_per_problem_ * params.divmod_splits_.divisor + params.separate_reduction_units_)) { - // Invalid work. Return an empty result. - return WorkTileInfo::invalid_work_tile(); - } - - WorkTileInfo work_tile_info; - assign_work(params, linear_idx, block_id_in_cluster, work_tile_info, unit_iter_start); - return work_tile_info; - } - - // Returns whether the current work_tile_info passed in should continue to be used. This - // occurs only in the stream-K decomposition with stream-K work units, which encompass - // work over multiple output tiles. If the current work_tile_info should continue to be - // used, it is updated to advance to the next output tile it should cover. - CUTLASS_DEVICE - bool - continue_current_work(WorkTileInfo& work_tile_info) const { - return continue_current_work_for_linear_idx( - current_work_linear_idx_, unit_iter_start_, block_id_in_cluster_, work_tile_info, scheduler_params); - } - - CUTLASS_DEVICE - static bool - continue_current_work_for_linear_idx( - uint64_t linear_idx, - uint32_t unit_iter_start, - dim3 block_id_in_cluster, - WorkTileInfo& work_tile_info, - Params const& params) { - - work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count; - - if (work_tile_info.k_tile_remaining == 0) { - return false; - } - fast_assign_work(unit_iter_start, params, linear_idx, block_id_in_cluster, work_tile_info); - return work_tile_info.is_valid(); - } - - CUTLASS_DEVICE - void - advance_to_next_work(uint32_t advance_count = 1) { - current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count); - } - - CUTLASS_DEVICE - bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const { - // Never pass this by reference; it needs a copy, - // because continue_current_work will modify it. - if (continue_current_work(work_tile_info)) { - return false; - } - return not get_current_work_for_linear_idx( - unit_iter_start_, - current_work_linear_idx_ + ( - uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count) - ), - block_id_in_cluster_, - scheduler_params - ).is_valid(); - } - - // Given the inputs, computes the total number of output blocks this problem will compute over - // Note that this is only the logical size of our grid, not the physical grid we will actually launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_tiled_cta_shape_mnl(ProblemShape problem_shape_mnkl, TileShape cta_shape, ClusterShape cluster_shape) { - return UnderlyingScheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape); - } - - // Given the cluster shape, computes the physical grid we should launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - [[maybe_unused]] Params const& params, - ProblemShape problem_shape, - TileShape tile_shape, - ClusterShape cluster_shape, - KernelHardwareInfo hw_info, - Arguments arguments) { - - auto problem_shape_mnkl = cute::append<4>(problem_shape, cute::Int<1>{}); - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - - return Params::get_grid_shape( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order - ); - } - - // Returns whether fixup is needed for `work_tile_info`. - CUTLASS_HOST_DEVICE - static bool - requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) { - // Fixup is not needed for invalid or data-parallel tiles - return work_tile_info.is_valid() && work_tile_info.k_tile_count != params.divmod_tiles_per_output_tile_.divisor; - } - - CUTLASS_HOST_DEVICE - static bool - requires_separate_reduction(Params const& params) { - return params.requires_separate_reduction(); - } - - // When the work tile is not special for reduction, it's valid. Otherwise need to skip - // global loading that producer warpgroup do, also math computation that consumer warpgroup do. - CUTLASS_DEVICE - static bool - valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { - return !work_tile_info.is_reduction_unit(); - } - - // Performs the reduction across splits for a given output tile. - template - CUTLASS_DEVICE - static void - fixup( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - static constexpr uint32_t Offset = static_cast(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); - static constexpr uint32_t MaxNumNamedBarriers = 2; - using BarrierManager = NamedBarrierManager; - return fixup_helper( - params, work_tile_info, accumulators, num_barriers, barrier_idx); - } - - // Helper for performing the reduction across splits for a given output tile. - template - CUTLASS_DEVICE - static void - fixup_helper( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx, - uint32_t num_accumulator_mtxs = 1, - uint32_t idx_accumulator_mtxs = 0) { - - using ElementAccumulator = typename FrgTensorC::value_type; - - if (!requires_fixup(params, work_tile_info)) { - return; - } - uint64_t tile_idx = output_tile_index(params, work_tile_info); - - // Index of the lock on which to wait - uint64_t lock_idx = (tile_idx * num_barriers) + barrier_idx; - - uint64_t reduction_tile_idx = tile_idx; - uint64_t num_peers = 0; - uint64_t reduction_peer_offset = 0; - if ( - params.requires_separate_reduction() - ) { - // If separate reduction is to be performed, each stream-K unit writes its partials - // to a separate portion of the workspace. There are as many of these portions as there - // are peers for a given output tile, so we multiply the tile index by the maximum peer count. - auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, work_tile_info); - auto peer_id_in_output_tile = my_peer_id - first_peer_id; - num_peers = last_peer_id - first_peer_id + 1; - reduction_tile_idx = tile_idx * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); - reduction_peer_offset = peer_id_in_output_tile * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * num_accumulator_mtxs; - } - - // Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood. - // Thus, the start of the reduction space is the same across all threads in a warp group. - uint64_t reduction_offset_base = (static_cast(cute::size<0>(TileShape{})) * static_cast(cute::size<1>(TileShape{})) * reduction_tile_idx * num_accumulator_mtxs) + - (static_cast(size(accumulators)) * barrier_idx * BarrierManager::ThreadCount * num_accumulator_mtxs) - + static_cast(size(accumulators)) * BarrierManager::ThreadCount * idx_accumulator_mtxs; - uint64_t reduction_offset = reduction_offset_base + reduction_peer_offset; - - ElementAccumulator* group_reduction_workspace = reinterpret_cast(params.reduction_workspace_) + reduction_offset; - - using AccumulatorArrayT = Array; - using BlockStripedReduceT = BlockStripedReduce; - - AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); - AccumulatorArrayT* accumulator_array = reinterpret_cast(accumulators.data()); - - uint32_t barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; - - // The number of tiles for which reduction is required is either: - // (a) the total number of output tiles (in the case of split-K) - // (b) the number of stream-K tiles (potentially multiplied by peer count if using separate reduction) - // To calculate the total number of output tiles in the split-K case, we - // note that, in the split-K case, the units_per_problem_ member of Params will be - // the total number of output tiles. - uint32_t reduction_tiles = 0; - if (params.divmod_splits_.divisor > 1) { - reduction_tiles = params.units_per_problem_; - } - else if ( - params.requires_separate_reduction() - ) { - reduction_tiles = params.sk_tiles_ * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); - } - else { - reduction_tiles = params.sk_tiles_; - } - - uint64_t reduction_workspace_size = Params::get_reduction_workspace_size( - reduction_tiles, to_gemm_coord(TileShape{}), sizeof_bits::value, num_accumulator_mtxs); - BarrierType* lock_workspace = reinterpret_cast( - reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); - - if (work_tile_info.is_reduction_unit()) { - // Wait until the peers collaborating on this output tile have all written - // their accumulators to workspace. - BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, num_peers); - - separate_reduction(accumulators, num_barriers, group_reduction_workspace, barrier_group_thread_idx, num_peers, num_accumulator_mtxs); - } - else if (!compute_epilogue(work_tile_info, params)) { - if ( - params.requires_separate_reduction() - || work_tile_info.K_idx == 0 - ) { - // The first peer initializes the workspace partials in the non-separate-reduction case, - // and all peers write to their own location in workspace when using separate reduction - BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); - } - else { - if (params.reduction_mode_ == ReductionMode::Deterministic) { - // Wait until the preceding split added its accumulators - BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); - } - else { - // Wait until the first split has stored its accumulators. Note that the first split will have - // accumulated a value into the lock potentially greater than one (since the locked value is - // incremented by work_tile_info.k_tile_count below for both the deterministic and non-deterministic) - // cases. For non-deterministic reductions, all that non-first or last splits care about is whether - // the first split has been written, so we only wait while the locked value is less than 1. - BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); - } - - // Perform reduction in workspace - BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); - } - - // If separate reduction is being performed, each participating stream-K unit increments the barrier - // by only 1. Otherwise, increment by the K tile count that this unit has processed. - uint32_t increment = params.requires_separate_reduction() ? 1 : work_tile_info.k_tile_count; - - // Signal our arrival - if (idx_accumulator_mtxs == (num_accumulator_mtxs - 1)) { - BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment); - } - } - else { - // Wait until the preceding split added its accumulators - BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); - - // The block computing the final split for the tile adds previously-reduced partials - // to its accumulators and computes the epilogue. - BlockStripedReduceT::load_add(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx); - } - } - - template - CUTLASS_DEVICE - static void - separate_reduction( - FrgTensorC& accumulators, - uint32_t num_barriers, - typename FrgTensorC::value_type* reduction_workspace, - uint32_t thread_idx, - uint64_t num_peers, - uint32_t num_accumulator_mtxs) { - using AccumulatorArrayT = Array; - using BlockStripedReduceT = BlockStripedReduce; - - AccumulatorArrayT* accumulator_array = reinterpret_cast(accumulators.data()); - - plus add_fragments; - uint64_t peer_offset = cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * num_accumulator_mtxs; - - for (uint64_t i = 0; i < num_peers; ++i) { - // Load peer fragment - AccumulatorArrayT addend_fragment; - auto peer_reduction_workspace = reinterpret_cast(reduction_workspace + (i * peer_offset)); - - BlockStripedReduceT::load_add(*accumulator_array, peer_reduction_workspace, thread_idx); - } - } - - // Returns whether the block assigned this work should compute the epilogue for the corresponding - // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. - CUTLASS_HOST_DEVICE - static bool - compute_epilogue(WorkTileInfo const& work_tile_info, Params const& params) { - // `is_final_split` will be set to `true` for the following scenarios, all of which must compute the epilogue: - // 1. The tile is computed in data-parallel mode - // 2. The tile is computed in split-/stream-K mode and this work unit represents the final split of the tile - // 3. The tile is computed in split-/stream-K mode and separate reduction is used, and this is a separate reduction unit - return work_tile_info.is_valid() && - (work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor) && - !params.requires_separate_reduction()) || work_tile_info.is_separate_reduction; - } - - // Returns the linearized index of the output tile corresponding to the tile with offset [L, M, K] - CUTLASS_DEVICE - static uint64_t - output_tile_index(Params const& params, WorkTileInfo const& work_tile_info) { - uint64_t linear_idx_in_batch = UnderlyingScheduler::get_linear_idx_from_m_and_n( - work_tile_info.M_idx, work_tile_info.N_idx, - params.divmod_cluster_shape_major_, - params.divmod_cluster_shape_minor_, - params.divmod_cluster_blk_major_, - params.log_swizzle_size_, - params.raster_order_ - ); - - uint64_t tiles_mn = params.divmod_batch_.divisor; - return tiles_mn * work_tile_info.L_idx + linear_idx_in_batch; - } - - template - static size_t - get_workspace_size( - Arguments const& args, - ProblemShape problem_shape, - KernelHardwareInfo const& hw_info, - uint32_t mma_warp_groups, - const uint32_t epilogue_subtile = 1, - [[maybe_unused]] uint32_t num_accumulator_mtxs = 1) { - - auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); - - ClusterShape cluster_shape; - TileShape tile_shape; - - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); - - return Params::get_workspace_size( - problem_blocks, - k_tile_per_output_tile, - to_gemm_coord(tile_shape), - to_gemm_coord(cluster_shape), - hw_info, - args.splits, - args.max_swizzle_size, - args.raster_order, - args.decomposition_mode, - args.reduction_mode, - mma_warp_groups, - sizeof_bits::value, - sizeof_bits::value, - epilogue_subtile - ); - } - - template - static cutlass::Status - initialize_workspace( - Arguments const& args, - void* workspace, - cudaStream_t stream, - ProblemShape const& problem_shape, - KernelHardwareInfo const& hw_info, - uint32_t mma_warp_groups, - const uint32_t epilogue_subtile = 1, - [[maybe_unused]] uint32_t num_accumulator_mtxs = 1, - CudaHostAdapter* cuda_adapter = nullptr) { - - auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); - - ClusterShape cluster_shape; - TileShape tile_shape; - - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); - - return Params::initialize_workspace( - workspace, - stream, - problem_blocks, - k_tile_per_output_tile, - to_gemm_coord(tile_shape), - to_gemm_coord(cluster_shape), - hw_info, - args.splits, - args.max_swizzle_size, - args.raster_order, - args.decomposition_mode, - args.reduction_mode, - mma_warp_groups, - sizeof_bits::value, - sizeof_bits::value, - epilogue_subtile, - 1, - cuda_adapter - ); - } - - template - CUTLASS_HOST_DEVICE - static uint32_t - get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape, TileShape) { - return work_tile_info.k_tile_count; - } - - CUTLASS_HOST_DEVICE - static uint32_t - get_work_k_tile_start(WorkTileInfo const& work_tile_info) { - return work_tile_info.K_idx; - } - - // Kernel helper function to get next work tile - CUTLASS_DEVICE - auto - fetch_next_work(WorkTileInfo work_tile_info) { - if (continue_current_work(work_tile_info)) { - return cute::make_tuple(work_tile_info, true); - } - - advance_to_next_work(); - return cute::make_tuple(get_current_work(), true); - } - - // Kernel helper function to get next work tile - template - CUTLASS_DEVICE - auto - fetch_next_work( - WorkTileInfo work_tile_info, - TileSchedulerPipeline& scheduler_pipeline, - TileSchedulerPipelineState scheduler_pipe_consumer_state) { - return fetch_next_work(work_tile_info); - } - - // Returns the initial work tile info that will be computed over - CUTLASS_DEVICE - WorkTileInfo - initial_work_tile_info(ClusterShape) { - return get_current_work(); - } - - // Given raster order and current work tile linear index, reset cta m and n index in the cluster. - CUTLASS_DEVICE - static dim3 - get_current_work_cta_m_n_in_cluster( - Params const& params, - uint64_t linear_idx, - dim3 block_id_in_cluster) { - auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = block_id_in_cluster; - uint64_t cta_m_in_cluster = static_cast(cta_m_in_cluster_); - uint64_t cta_n_in_cluster = static_cast(cta_n_in_cluster_); - - // Determine the CTA's M and N offsets within the preferred cluster - // This simply finds the linear offset of the CTA within the cluster, and takes a divmod - // on it depending on the rasterization order used by the scheduler. - uint64_t cluster_linear_work_idx_tmp = params.div_cluster_size(linear_idx) * params.get_cluster_size(); - - if (params.raster_order_ == RasterOrder::AlongN) { - params.divmod_cluster_shape_minor_(cta_n_in_cluster, cta_m_in_cluster, linear_idx - cluster_linear_work_idx_tmp); - } - else { - params.divmod_cluster_shape_minor_(cta_m_in_cluster, cta_n_in_cluster, linear_idx - cluster_linear_work_idx_tmp); - } - - return {static_cast(cta_m_in_cluster), static_cast(cta_n_in_cluster), _}; - } - -private: - - CUTLASS_DEVICE - static uint32_t - get_current_work_iter_start_possible_update_work_tile_k_remaining( - Params const& params, - uint64_t linear_idx, - WorkTileInfo& work_tile_info) { - // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K - // threadblock individually. For the most part, the set of K iterations corresponding to stream-K - // work was divided amongst stream-K threadblocks, and a threadblock determined which tile - // it would compute a (potentially-partial) output tile for based on the space of k iterations - // assigned to it. This often results in stream-K threadblocks processing tiles with different - // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the - // (generally few) waves of threadblocks assigned to compute stream-K work. - // - // With the introduction of threadblock clusters, there is additional benefit to maintaining - // locality in the K dimension: shared portions of operands can be multicasted to threadblocks - // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to - // threadblocks respects the ability to perform multicasting. - // - // To do so, we divide up the linearized stream-K units into clusters and share the same K - // offsets for work within clusters. - uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); - - uint64_t group_idx; - params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx); - - // Determine whether we are in a "big group" that will process an additional - // stream-K cluster tile. - uint64_t sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_); - uint64_t sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles); - if (group_idx < params.big_groups_) { - ++sk_cluster_tiles_in_group; - } - - // Determine whether we are in a "big unit" within the group, that will process - // an additional K chunk in the group. - uint64_t sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size(); - uint64_t k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; - uint64_t k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); - uint64_t big_units_in_group = params.div_cluster_size( - k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor)); - - uint64_t split; - params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); - - bool is_split_k = params.divmod_splits_.divisor > 1; - uint64_t big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; - uint64_t big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; - uint64_t linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; - uint64_t k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; - - // Determine the starting k iteration computed by this stream-K work unit - uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + - (k_tiles_per_split * split); - - // Adjust the starting position and number of k iterations for "big units," which - // compute one extra iteration. If there are any big units, they will be the first - // in the linearized ID space. - auto k_tiles_in_my_split = k_tiles_per_split; - if (big_unit_cmp_lhs < big_unit_cmp_rhs) { - // Since the "big units" are the first units in the linearized ID space, each - // of the units preceding this big unit computed one extra iteration. Thus, - // we must offset our start iteration by the number of units that precede - // the current unit in the linearized ID space. - unit_iter_start += big_unit_cmp_lhs; - ++k_tiles_in_my_split; - } - else { - // Increment by one for each of the big clusters (since all big units precede this unit) - unit_iter_start += big_unit_cmp_rhs; - } - if (!is_split_k) { - // Adjust the unit starting position and number of tiles to avoid - // computing splits of size less than min_iters_per_sk_unit_ - int unused, start_tile_k_tile; - params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); - if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { - // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another - // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to take over these K tiles. - unit_iter_start -= start_tile_k_tile; - k_tiles_in_my_split += start_tile_k_tile; - } - else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { - // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. - auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); - unit_iter_start += adjustment_tiles; - k_tiles_in_my_split -= adjustment_tiles; - } - else if (params.ktile_start_alignment_count_ == 2 && start_tile_k_tile % 2 != 0) { - // ktile for each SM start from even number - // If start from odd number ktile within the output tile - // now start at the ktile one before my initial ktile start (take one ktile from prev sm) - // if end on odd number ktile within the output tile - // now end at ktile that one before my ktile end (give one ktile to next sm) - unit_iter_start -= 1; - k_tiles_in_my_split += 1; - } - } - if (work_tile_info.k_tile_count == 0) { - // This is a new unit - - if (!is_split_k) { - // - // Adjust the unit ending position and number of tiles to avoid - // computing splits of size less than min_iters_per_sk_unit_ - // - - // Begin by assuming that no adjustment is needed - auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; - - int unused, end_tile_k_tile; - params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); - - if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { - // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. - k_tiles_in_my_split -= end_tile_k_tile; - } - else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { - // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to take on these K tiles. - k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); - } - else if (params.ktile_start_alignment_count_ == 2 && end_tile_k_tile % 2 != 0) { - // ktile for each SM start from even number - // If start from odd number ktile within the output tile - // now start at the ktile one before my initial ktile start (take one ktile from prev sm) - // If end on odd number ktile within the output tile, - // now end at ktile that one before my ktile end (give one ktile to next sm) - k_tiles_in_my_split -= 1; - } - } - - work_tile_info.k_tile_remaining = k_tiles_in_my_split; - } - return unit_iter_start; - } - - // Update output tile index given existing remaining k tiles of current work tile. - CUTLASS_DEVICE - static uint64_t update_output_tile_id_and_work_tile_k( - Params const& params, - WorkTileInfo& work_tile_info, - uint64_t linear_idx, - uint32_t unit_iter_start, - uint64_t cta_m_in_cluster, - uint64_t cta_n_in_cluster) { - // we divide up the linearized stream-K units into clusters and share the same K - // offsets for work within clusters. - uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); - - uint64_t unused, group_idx; - params.divmod_sk_groups_(unused, group_idx, cluster_linear_work_idx); - - uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; - - // Find the output tile corresponding to the final k tile covered by this - // work unit. Stream-K work units will work backwards in terms of the tiles they - // are responsible computing. This is beneficial because the final (partial) - // tile computed by a stream-K block is typically the beginning of the output - // tile, while the beginning (partial) tile is typically the ending of another - // output tile. Since ending portions of an output tile must reduce across - // other work units computing portions of that output tile, it is preferable - // for them to be computed later, so as to reduce the likelihood of blocking - // on other work. - - auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); - uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; - uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; - - // Convert the output tile from the linearized space within each group to the - // overall linearized space. - uint64_t output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx; - - // Bring the linearized tile ID back into the space of tiles, rather than clusters - output_tile_id *= params.get_cluster_size(); - - // The final linearized tile ID is in units of the cluster dimension over which we rasterize. - if (params.raster_order_ == RasterOrder::AlongN) { - output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } - else { - output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } - // The unit's starting k iteration in the current tile is either the starting - // iteration for the tile as a whole, or the starting k iteration for the unit - // as a whole (if the latter is greater than the former). - uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); - - // Similarly, the unit's ending k iteration (exclusive) is either the end of - // the current tile it is assigned, or the ending iteration of the unit as a whole - // (if the latter is less than the former). - uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); - - // Set the k offset to be the starting k tile for this output tile - work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); - work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; - - return output_tile_id; - } - // Given output tile index, update M, N, L index of current work tile info. - CUTLASS_DEVICE - static void - update_work_tile_m_n_l( - Params const& params, - uint32_t output_tile_id, - WorkTileInfo& work_tile_info, - uint64_t cta_m_in_cluster, - uint64_t cta_n_in_cluster) { - - uint64_t work_idx_l, remainder; - params.divmod_batch_(work_idx_l, remainder, output_tile_id); - - uint64_t cta_per_grid_dim = params.divmod_cluster_shape_minor_.divide(remainder); - - auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( - cta_per_grid_dim, - params.divmod_cluster_shape_major_, - params.divmod_cluster_shape_minor_, - params.divmod_cluster_blk_major_, - params.log_swizzle_size_, - params.raster_order_ - , cta_m_in_cluster - , cta_n_in_cluster - ); - - // Set the M, N, and L block offsets - work_tile_info.M_idx = work_idx_m; - work_tile_info.N_idx = work_idx_n; - work_tile_info.L_idx = static_cast(work_idx_l); - } - - // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info - // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining - // iterations) is used to find the next tile in the current work unit. - CUTLASS_DEVICE - static void - assign_work( - Params const& params, - uint64_t linear_idx, - dim3 block_id_in_cluster, - WorkTileInfo& work_tile_info, - uint32_t &unit_iter_start) { - - auto [cta_m_in_cluster, cta_n_in_cluster, _] = - get_current_work_cta_m_n_in_cluster(params, linear_idx, block_id_in_cluster); - - uint64_t output_tile_id = linear_idx; - if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) { - // Separate-reduction work - auto cluster_size = params.get_cluster_size(); - // Divide up the linearized separate reduction units into clusters - uint64_t cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_)); - uint64_t cluster_tile_idx, epi_subtile_idx; - params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx); - // Bring the linearized tile ID back into the space of tiles, rather than clusters - output_tile_id = cluster_tile_idx * cluster_size; - - work_tile_info.setup_separate_reduction(epi_subtile_idx); - } - else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { - // Data-parallel work - output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; - work_tile_info.K_idx = 0; - work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; - work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; - } - else { - unit_iter_start = get_current_work_iter_start_possible_update_work_tile_k_remaining(params, linear_idx, work_tile_info); - output_tile_id = update_output_tile_id_and_work_tile_k(params, work_tile_info, - linear_idx, unit_iter_start, cta_m_in_cluster, cta_n_in_cluster); - } - update_work_tile_m_n_l(params, output_tile_id, work_tile_info, cta_m_in_cluster, cta_n_in_cluster); - } - - // The fast path to get current output tile index then update fields of work tile info - // when continuing current work tile is needed, since k tile starting index has precomputed - // in the first time fetching current work tile. - CUTLASS_DEVICE - static void - fast_assign_work( - uint32_t unit_iter_start, - Params const& params, - uint64_t linear_idx, - dim3 block_id_in_cluster, - WorkTileInfo& work_tile_info) { - - auto [cta_m_in_cluster, cta_n_in_cluster, _] = - get_current_work_cta_m_n_in_cluster(params, linear_idx, block_id_in_cluster); - - uint64_t output_tile_id = update_output_tile_id_and_work_tile_k(params, work_tile_info, - linear_idx, unit_iter_start, cta_m_in_cluster, cta_n_in_cluster); - - update_work_tile_m_n_l(params, output_tile_id, work_tile_info, cta_m_in_cluster, cta_n_in_cluster); - } - - // Returns the starting and ending peer ID of this tile - CUTLASS_HOST_DEVICE - static auto - tile_peer_range(Params const& params, uint32_t tile_idx, WorkTileInfo const& work_tile_info) { - uint32_t cur_k_tile = static_cast(work_tile_info.K_idx); - uint32_t tile_idx_in_cluster_path = params.div_cluster_size(tile_idx); - uint32_t start_k_tile = params.divmod_tiles_per_output_tile_.divisor * tile_idx_in_cluster_path; - uint32_t end_k_tile = start_k_tile + params.divmod_tiles_per_output_tile_.divisor - 1; - uint32_t big_unit_k_tiles = params.big_units_ * (params.divmod_k_tiles_per_sk_unit_.divisor + 1); - - auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t unit_k_start, uint32_t unit_k_end) { - if (k_tile - start_k_tile < Params::min_iters_per_sk_unit_ && - unit_k_end - start_k_tile < Params::min_iters_per_sk_unit_) { - // k_tile is within the first min_iters_per_sk_unit_ K tiles of this output tile, - // and the stream-K unit computes fewer than min_iters_per_sk_unit_ K tiles for this - // output tile. This work will thus be subsumed by the next stream-K unit. - ++unit_idx; - } - - if (end_k_tile + 1 - k_tile < Params::min_iters_per_sk_unit_ && - end_k_tile + 1 - unit_k_start < Params::min_iters_per_sk_unit_) { - // k_tile is within the last min_iters_per_sk_unit_ K tiles of this output tile, - // and the stream-K unit computes fewer than min_iters_per_sk_unit_ K tiles for this - // output tile. This work will thus be subsumed by the previous stream-K unit. - --unit_idx; - } - return unit_idx; - }; - - // Lambda to find the ID of the stream-K unit that computes this K tile - auto find_unit = [&](uint32_t k_tile) { - if (k_tile < big_unit_k_tiles) { - // The tile is within the "big unit range" - uint32_t unit_idx = params.divmod_k_tiles_per_sk_big_unit_.divide(k_tile); - uint32_t unit_k_start = unit_idx * params.divmod_k_tiles_per_sk_big_unit_.divisor; - uint32_t unit_k_end = unit_k_start + params.divmod_k_tiles_per_sk_big_unit_.divisor; - return static_cast(adjust_unit(k_tile, unit_idx, unit_k_start, unit_k_end)); - } - else { - // The tile is after the "big unit range." Account for this by finding the "normal unit" - // that it belongs to, and then offsetting by the number of big units - uint32_t unit_idx_after_big_units = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles); - uint32_t unit_k_start = unit_idx_after_big_units * params.divmod_k_tiles_per_sk_unit_.divisor + (params.big_units_ * params.divmod_k_tiles_per_sk_big_unit_.divisor); - uint32_t unit_k_end = unit_k_start + params.divmod_k_tiles_per_sk_unit_.divisor; - uint32_t unit_idx = unit_idx_after_big_units + params.big_units_; - return static_cast(adjust_unit(k_tile, unit_idx, unit_k_start, unit_k_end)); - } - }; - - return cute::make_tuple(find_unit(start_k_tile), find_unit(start_k_tile + cur_k_tile), find_unit(end_k_tile)); - } -}; - -} // namespace cutlass::gemm::kernel::detail diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h deleted file mode 100644 index 84102a6c933fcc6e80604ccd232db8ca033c0d56..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h +++ /dev/null @@ -1,394 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/params_sparse_base.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -> -struct SparseGemm { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using OutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - static int const kSparse = Mma::kSparse; - static int const kMetaSizeInBits = Mma::kMetaSizeInBits; - static int const kMaxID2 = Mma::kMaxID2; - static int const kElementsPerElementE = Mma::kElementsPerElementE; - - using ElementE = typename Mma::ElementE; - using LayoutE = typename Mma::LayoutE; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ParamsA = typename Mma::IteratorA::Params; - using TensorRefA = typename Mma::IteratorA::TensorRef; - using ParamsB = typename Mma::IteratorB::Params; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using ParamsE = typename Mma::IteratorE::Params; - using TensorRefE = typename Mma::IteratorE::TensorRef; - - /// Parameters structure - struct Params : public SparseParamsBase< - ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, - ParamsE, TensorRefE> { - - using Base = SparseParamsBase< - ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, - ParamsE, TensorRefE>; - - // - // Data members - // - - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename OutputOp::Params output_op; - int *semaphore; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - TensorRefA ref_A, - TensorRefB ref_B, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, - TensorRefE ref_E, - typename OutputOp::Params output_op = typename OutputOp::Params(), - int *workspace = nullptr - ): - Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK), - params_C(ref_C.layout()), - ref_C(ref_C), - params_D(ref_D.layout()), - ref_D(ref_D), - output_op(output_op) { - semaphore = workspace; - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - SparseGemm() { } - - /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, - typename Mma::IteratorE::TensorRef ref_E) { - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - static int const kAlignmentE = Mma::IteratorE::AccessType::kElements; - - if (!TensorRef_aligned(ref_A, kAlignmentA)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_B, kAlignmentB)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_C, kAlignmentC)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_D, kAlignmentC)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_E, kAlignmentE)) { - return Status::kErrorMisalignedOperand; - } - - if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) || - (problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) { - - return Status::kErrorMisalignedOperand; - } - - // The k dimension has to be the multiple of the Threadblock k because out - // of bound meta data would be initialized to 0 by acync.zfill but 0 is not - // a valid meta data. - if (problem_size.k() % Mma::Shape::kK) { - return Status::kErrorMisalignedOperand; - } - - // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) - // because of the row reordering of operand E - static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16; - - if (problem_size.m() % kAlignmentM) { - return Status::kErrorMisalignedOperand; - } - - return Status::kSuccess; - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size / kSparse, - }; - - cutlass::MatrixCoord tb_offset_B{ - threadblock_tile_offset.k() * params.gemm_k_size, - threadblock_tile_offset.n() * Mma::Shape::kN - }; - - cutlass::MatrixCoord tb_offset_E{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size / kSparse, - }; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min( - params.problem_size.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A, B, and E operands - typename Mma::IteratorA iterator_A( - params.params_A, - params.ref_A.data(), - {params.problem_size.m(), problem_size_k / kSparse}, - thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, - params.ref_B.data(), - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B); - - typename Mma::IteratorE iterator_E( - params.params_E, params.ref_E.data(), - {params.problem_size.m(), - problem_size_k / kSparse / kElementsPerElementE}, - thread_idx, tb_offset_E); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); - } - - // - // Epilogue - // - - OutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - //assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - params.ref_C.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - params.ref_D.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - - __threadfence(); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - __threadfence(); - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h deleted file mode 100644 index 0574c21823be1b492abb5dc1766ee87a4f12d8bd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_absmax.h +++ /dev/null @@ -1,509 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Sparse GEMM kernel with an epilogue that computes the absolute maximum value of the output - and a pre-activation-function auxiliary output. The auxiliary output is also (optionally) - stored to global memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/params_sparse_base.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -> -struct SparseGemmWithAbsmax { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using OutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - static int const kSparse = Mma::kSparse; - static int const kMetaSizeInBits = Mma::kMetaSizeInBits; - static int const kMaxID2 = Mma::kMaxID2; - static int const kElementsPerElementE = Mma::kElementsPerElementE; - - using ElementE = typename Mma::ElementE; - using LayoutE = typename Mma::LayoutE; - - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ParamsA = typename Mma::IteratorA::Params; - using TensorRefA = typename Mma::IteratorA::TensorRef; - using ParamsB = typename Mma::IteratorB::Params; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using ParamsE = typename Mma::IteratorE::Params; - using TensorRefE = typename Mma::IteratorE::TensorRef; - - using ParamsC = typename Epilogue::OutputTileIterator::Params; - using TensorRefC = typename Epilogue::OutputTileIterator::TensorRef; - using ParamsD = typename Epilogue::OutputTileIterator::Params; - using TensorRefD = typename Epilogue::OutputTileIterator::TensorRef; - using ParamsAux = typename Epilogue::AuxOutputTileIterator::Params; - using TensorRefAux = typename Epilogue::AuxOutputTileIterator::TensorRef; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmCoord problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefC ref_C; - TensorRefD ref_D; - TensorRefE ref_E; - TensorRefAux ref_Aux; - void* ptr_Vector; - typename LayoutC::Stride::Index ldr; - - typename Epilogue::OutputOp::Params epilogue; - int split_k_slices; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments(): problem_size(0, 0, 0), split_k_slices(1) { - - } - - /// Constructs an Arguments structure - CUTLASS_HOST_DEVICE - Arguments( - GemmCoord problem_size_, - TensorRefA ref_A_, - TensorRefB ref_B_, - TensorRefC ref_C_, - TensorRefD ref_D_, - TensorRefE ref_E_, - TensorRefAux ref_Aux_, - void* ptr_Vector_, - typename LayoutC::Stride::Index ldr_, - typename OutputOp::Params epilogue_ = - typename OutputOp::Params(), - int split_k_slices = 1 - ): - problem_size(problem_size_), - ref_A(ref_A_), - ref_B(ref_B_), - ref_C(ref_C_), - ref_D(ref_D_), - ref_E(ref_E_), - ref_Aux(ref_Aux_), - ptr_Vector(ptr_Vector_), - ldr(ldr_), - epilogue(epilogue_), - split_k_slices(split_k_slices) { - - } - }; - - /// Parameters structure - struct Params : public SparseParamsBase< - ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, - ParamsE, TensorRefE> { - - using Base = SparseParamsBase< - ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, - ParamsE, TensorRefE>; - - // - // Data members - // - - ParamsC params_C; - TensorRefC ref_C; - ParamsD params_D; - TensorRefD ref_D; - ParamsAux params_Aux; - TensorRefAux ref_Aux; - - void* ptr_Vector; - typename LayoutC::Stride::Index ldr; - - typename OutputOp::Params output_op; - int *semaphore; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - TensorRefA ref_A, - TensorRefB ref_B, - TensorRefC ref_C, - TensorRefD ref_D, - TensorRefE ref_E, - TensorRefAux ref_Aux, - void* ptr_Vector, - typename LayoutC::Stride::Index ldr, - typename OutputOp::Params output_op = typename OutputOp::Params(), - int *workspace = nullptr - ): - Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK), - params_C(ref_C.layout()), - ref_C(ref_C), - params_D(ref_D.layout()), - ref_D(ref_D), - output_op(output_op), - ref_Aux(ref_Aux), - params_Aux(ref_Aux.layout()), - ptr_Vector(ptr_Vector), - ldr(ldr) { - semaphore = workspace; - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - SparseGemmWithAbsmax() { } - - /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, - typename Mma::IteratorE::TensorRef ref_E) { - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - static int const kAlignmentE = Mma::IteratorE::AccessType::kElements; - - if (!TensorRef_aligned(ref_A, kAlignmentA)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_B, kAlignmentB)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_C, kAlignmentC)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_D, kAlignmentC)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(ref_E, kAlignmentE)) { - return Status::kErrorMisalignedOperand; - } - - if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) || - (problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) { - - return Status::kErrorMisalignedOperand; - } - - // The k dimension has to be the multiple of the Threadblock k because out - // of bound meta data would be initialized to 0 by acync.zfill but 0 is not - // a valid meta data. - if (problem_size.k() % Mma::Shape::kK) { - return Status::kErrorMisalignedOperand; - } - - // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) - // because of the row reordering of operand E - static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16; - - if (problem_size.m() % kAlignmentM) { - return Status::kErrorMisalignedOperand; - } - - return Status::kSuccess; - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size / kSparse, - }; - - cutlass::MatrixCoord tb_offset_B{ - threadblock_tile_offset.k() * params.gemm_k_size, - threadblock_tile_offset.n() * Mma::Shape::kN - }; - - cutlass::MatrixCoord tb_offset_E{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size / kSparse, - }; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min( - params.problem_size.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A, B, and E operands - typename Mma::IteratorA iterator_A( - params.params_A, - params.ref_A.data(), - {params.problem_size.m(), problem_size_k / kSparse}, - thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, - params.ref_B.data(), - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B); - - typename Mma::IteratorE iterator_E( - params.params_E, params.ref_E.data(), - {params.problem_size.m(), - problem_size_k / kSparse / kElementsPerElementE}, - thread_idx, tb_offset_E); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); - } - - // - // Epilogue - // - - OutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - //assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - typename Epilogue::ElementVector *ptr_Vector = static_cast(params.ptr_Vector); - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - params.ref_C.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - params.ref_D.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to auxiliary destination tensor. - typename Epilogue::AuxOutputTileIterator iterator_Aux( - params.params_Aux, - // Only the final block writes the auxiliary tensor - ((kSplitKSerial && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : params.ref_Aux.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - - __threadfence(); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - // Only the final block uses Vector - ((kSplitKSerial && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Vector, - iterator_D, - accumulators, - iterator_C, - iterator_Aux, - params.problem_size.mn(), - threadblock_offset); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - __threadfence(); - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h deleted file mode 100644 index a8ec1c3dc091dd5d14a2b1c1d71897b7af272546..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h +++ /dev/null @@ -1,238 +0,0 @@ - -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Sparse GEMM with visitor. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/kernel/sparse_gemm.h" -#include "cutlass/gemm/kernel/params_sparse_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Sparse Gemm that compute the epilogue visitor functor -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_ ///! Threadblock swizzling function -> -struct SparseGemmWithEpilogueVisitor : public SparseGemm { - - using Base = SparseGemm; - - using Mma = Mma_; - using Epilogue = Epilogue_; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using FusionCallbacks = typename Epilogue::FusionCallbacks; - - using ParamsA = typename Mma::IteratorA::Params; - using TensorRefA = typename Mma::IteratorA::TensorRef; - using ParamsB = typename Mma::IteratorB::Params; - using TensorRefB = typename Mma::IteratorB::TensorRef; - using ParamsE = typename Mma::IteratorE::Params; - using TensorRefE = typename Mma::IteratorE::TensorRef; - - static int const kSparse = Base::kSparse; - static int const kElementsPerElementE = Base::kElementsPerElementE; - using SharedStorage = typename Base::SharedStorage; - - /// Parameters structure - struct Params : public SparseParamsBase< - ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, - ParamsE, TensorRefE> { - - using Base = SparseParamsBase< - ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, - ParamsE, TensorRefE>; - - // - // Data members - // - - typename FusionCallbacks::Params output_op; - cute::Shape problem_shape; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorE::TensorRef ref_E, - typename FusionCallbacks::Arguments output_op = typename FusionCallbacks::Arguments() - ): - Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK), - output_op(FusionCallbacks::to_underlying_arguments(problem_size, output_op, nullptr /*workspace*/)), - problem_shape(problem_size.m(), problem_size.n(), 1) { - } - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - SparseGemmWithEpilogueVisitor() { } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size / kSparse, - }; - - cutlass::MatrixCoord tb_offset_B{ - threadblock_tile_offset.k() * params.gemm_k_size, - threadblock_tile_offset.n() * Mma::Shape::kN - }; - - cutlass::MatrixCoord tb_offset_E{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size / kSparse, - }; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min( - params.problem_size.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A, B, and E operands - typename Mma::IteratorA iterator_A( - params.params_A, - params.ref_A.data(), - {params.problem_size.m(), problem_size_k / kSparse}, - thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, - params.ref_B.data(), - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B); - - typename Mma::IteratorE iterator_E( - params.params_E, params.ref_E.data(), - {params.problem_size.m(), - problem_size_k / kSparse / kElementsPerElementE}, - thread_idx, tb_offset_E); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (gemm_k_iterations > 0) { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); - } - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // - // Epilogue - // - - Epilogue epilogue( - params.output_op, - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp deleted file mode 100644 index f8319b1157b1e6c9df5be1b444e9d3813a1a2bae..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/static_tile_scheduler.hpp +++ /dev/null @@ -1,513 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/fast_math.h" -#include "cutlass/gemm_coord.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/gemm/kernel/tile_scheduler_params.h" -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/pipeline/pipeline.hpp" -namespace cutlass::gemm::kernel::detail { - -/////////////////////////////////////////////////////////////////////////////// - -// Users are not supposed to use this class directly. -// This is a CRTP base class for the actual tile schedulers. -template -class StaticPersistentTileScheduler { - -private: - uint64_t current_work_linear_idx_; - uint64_t total_grid_size_; - -public: - struct WorkTileInfo { - int32_t M_idx = 0; - int32_t N_idx = 0; - int32_t L_idx = 0; - bool is_valid_tile = false; - - CUTLASS_HOST_DEVICE - bool - is_valid() const { - return is_valid_tile; - } - - CUTLASS_HOST_DEVICE - static WorkTileInfo - invalid_work_tile() { - return {-1, -1, -1, false}; - } - - CUTLASS_HOST_DEVICE - bool - is_final_split(uint32_t k_tiles_per_output_tile) const { - return true; - } - - CUTLASS_HOST_DEVICE - int32_t - reduction_subtile_idx() const { - return -1; - } - }; - - using Params = PersistentTileSchedulerSm90Params; - using RasterOrder = typename Params::RasterOrder; - using RasterOrderOptions = typename Params::RasterOrderOptions; - static constexpr bool IsDynamicPersistent = false; - -public: - struct Arguments { - int max_swizzle_size = 1; - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; - }; - - template - static Params - to_underlying_arguments( - ProblemShapeMNKL problem_shape_mnkl, - TileShape tile_shape, - ClusterShape cluster_shape, - [[maybe_unused]] KernelHardwareInfo const& hw_info, - Arguments const& arguments, - [[maybe_unused]] void* workspace=nullptr, - [[maybe_unused]] const uint32_t epilogue_subtile = 1, - [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) { - - // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic - static_assert(cute::is_static::value); - static_assert(cute::is_static::value); - - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - - Params params; - params.initialize( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order - ); - - return params; - } - - CUTLASS_HOST_DEVICE - static bool - can_implement(Arguments const& args) { - return args.max_swizzle_size >= 0; - } - - CUTLASS_HOST_DEVICE - StaticPersistentTileScheduler() { } - - CUTLASS_DEVICE explicit StaticPersistentTileScheduler(Params const& params_) : scheduler_params(params_) { - // MSVC requires protecting use of CUDA-specific nonstandard syntax, - // like blockIdx and gridDim, with __CUDA_ARCH__. -#if defined(__CUDA_ARCH__) - if (params_.raster_order_ == RasterOrder::AlongN) { - current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); - } - else { - current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); - } - - total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); -#else - CUTLASS_ASSERT(false && "This line should never be reached"); -#endif - } - - // Returns the initial work tile info that will be computed over - template - CUTLASS_DEVICE - WorkTileInfo - initial_work_tile_info(ClusterShape cluster_shape) { - return get_current_work(); - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work() const { - return get_current_work_for_linear_idx(current_work_linear_idx_); - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work_for_linear_idx(uint64_t linear_idx) const { - if (linear_idx >= scheduler_params.blocks_per_problem_) { - return WorkTileInfo::invalid_work_tile(); - } - - // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices - uint64_t work_idx_l, remainder; - scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx); - - uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder); - - auto [work_idx_m, work_idx_n] = Subclass::get_work_idx_m_and_n(blk_per_grid_dim, - scheduler_params.divmod_cluster_shape_major_, - scheduler_params.divmod_cluster_shape_minor_, - scheduler_params.divmod_cluster_blk_major_, - scheduler_params.log_swizzle_size_, - scheduler_params.raster_order_); - - return {work_idx_m, work_idx_n, static_cast(work_idx_l), true}; - } - - CUTLASS_DEVICE - void - advance_to_next_work(uint32_t advance_count = 1) { - current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); - } - - CUTLASS_DEVICE - bool is_last_tile(WorkTileInfo& work_tile_info, uint32_t advance_count = 1) const { - if (continue_current_work(work_tile_info)) { - return false; - } - return not get_current_work_for_linear_idx( - current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count)) - ).is_valid(); - } - - // Computes the linear index within a batch given M and N tile offsets within the batch. - // This essentially inverts the mapping performed in get_work_idx_m_and_n - static CUTLASS_DEVICE - uint64_t - get_linear_idx_from_m_and_n( - int32_t tile_m, - int32_t tile_n, - FastDivmodU64Pow2 const& divmod_cluster_shape_major, - FastDivmodU64Pow2 const& divmod_cluster_shape_minor, - FastDivmodU64 const& divmod_cluster_blk_major, - int32_t log_swizzle_size, - RasterOrder raster_order) { - - uint64_t minor_work_idx, major_work_idx, cluster_minor_offset; - if (raster_order == RasterOrder::AlongN) { - minor_work_idx = static_cast(tile_m); - major_work_idx = static_cast(tile_n); - uint64_t cluster_m = divmod_cluster_shape_minor.divide(tile_m) * divmod_cluster_shape_minor.divisor; - cluster_minor_offset = tile_m - cluster_m; - } - else { - major_work_idx = static_cast(tile_m); - minor_work_idx = static_cast(tile_n); - uint64_t cluster_n = divmod_cluster_shape_minor.divide(tile_n) * divmod_cluster_shape_minor.divisor; - cluster_minor_offset = tile_n - cluster_n; - } - - uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset; - cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset); - divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx); - - uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size; - uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1); - - uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major; - - uint64_t cluster_id = (extra << log_swizzle_size) | offset; - return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset; - } - - // Given the inputs, computes the total number of output blocks over which this problem will compute. - // Note that this is only the logical size of our grid, not the physical grid we will actually launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) { - auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape))); - auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape))); - - return Params::get_tiled_cta_shape_mnl( - to_gemm_coord(problem_shape_mnkl), - to_gemm_coord(cluster_shape), - cta_m, cta_n - ); - } - - // Reloaded interface that receives WorkTileInfo to deduce next work. - // Kernel helper function to get next work tile - CUTLASS_DEVICE - auto - fetch_next_work(WorkTileInfo work_tile_info) { - if (continue_current_work(work_tile_info)) { - return cute::make_tuple(work_tile_info, true); - } - - advance_to_next_work(); - return cute::make_tuple(get_current_work(), true); - } - - // Given the inputs, computes the total number of output blocks over which this problem will compute. - // Note that this is only the logical size of our grid, not the physical grid we will actually launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, - TileShape tile_shape_mnk, - AtomThrShape atom_thr_shape_mnk, - ClusterShape cluster_shape_mnk) { - auto [tiles_m, tiles_n, tiles_l] = product_each(ceil_div(select<0,1,3>(problem_shape_mnkl), take<0,2>(tile_shape_mnk))); - auto cta_m = round_nearest(tiles_m * size<0>(atom_thr_shape_mnk), size<0>(cluster_shape_mnk)); - auto cta_n = round_nearest(tiles_n * size<1>(atom_thr_shape_mnk), size<1>(cluster_shape_mnk)); - - return Params::get_tiled_cta_shape_mnl( - to_gemm_coord(problem_shape_mnkl), - to_gemm_coord(cluster_shape_mnk), - cta_m, cta_n - ); - } - - // Kernel helper function to get next work tile - template - CUTLASS_DEVICE - auto - fetch_next_work( - WorkTileInfo work_tile_info, - TileSchedulerPipeline& scheduler_pipeline, - TileSchedulerPipelineState scheduler_pipe_consumer_state) { - return fetch_next_work(work_tile_info); - } - - CUTLASS_DEVICE - static auto - work_tile_to_cta_coord(WorkTileInfo work_tile_info) { - // Get every cta coord in three dimensions of the cluster - auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster(); - return make_coord( - work_tile_info.M_idx + static_cast(cta_m_in_cluster), - work_tile_info.N_idx + static_cast(cta_n_in_cluster), - _, - work_tile_info.L_idx + static_cast(cta_l_in_cluster) - ); - } - - CUTLASS_DEVICE - static auto - work_tile_to_cta_coord(WorkTileInfo work_tile_info, dim3 block_id_in_cluster) { - // Get every cta coord in three dimensions of the cluster - auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = block_id_in_cluster; - return make_coord( - work_tile_info.M_idx + static_cast(cta_m_in_cluster), - work_tile_info.N_idx + static_cast(cta_n_in_cluster), - _, - work_tile_info.L_idx + static_cast(cta_l_in_cluster) - ); - } - - // Given the inputs, computes the physical grid we should launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - [[maybe_unused]] Params const& params, - ProblemShapeMNKL problem_shape_mnk, - BlockShape cta_shape, - ClusterShape cluster_shape, - KernelHardwareInfo hw_info, - Arguments arguments = Arguments{}, - bool truncate_by_problem_size=true) { - - auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{}); - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape); - - return Params::get_grid_shape( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order, - /* truncate_by_problem_size = */true - ); - } - - // Given the inputs, computes the physical grid we should launch. - template - static dim3 - get_grid_shape( - Params const& params, - ProblemShapeMNKL problem_shape_mnkl, - TileShape tile_shape_mnk, - AtomThrShape atom_thr_shape_mnk, - ClusterShape cluster_shape_mnk, - KernelHardwareInfo hw_info) { - - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cluster_shape_mnk); - Arguments args{}; - if constexpr (!std::is_const_v) { - args.max_swizzle_size = 1 << params.log_swizzle_size_; - } - args.raster_order = params.raster_order_ == RasterOrder::AlongN ? RasterOrderOptions::AlongN : RasterOrderOptions::AlongM; - - return Params::get_grid_shape( - problem_blocks, - to_gemm_coord(cluster_shape_mnk), - hw_info, - args.max_swizzle_size, - args.raster_order, - /* truncate_by_problem_size = */true - ); - } - - // Convert CTA-level work tile info to cluster-level tile coord - CUTLASS_DEVICE - auto - work_tile_to_cluster_coord_mnkl(WorkTileInfo work_tile_info) const { - // TileScheduler works at CTA-level, kernel works at cluster-level - int m_coord = idx2crd(work_tile_info.M_idx / scheduler_params.cluster_shape_m_, - scheduler_params.problem_tiles_m_); - int n_coord = idx2crd(work_tile_info.N_idx / scheduler_params.cluster_shape_n_, - scheduler_params.problem_tiles_n_); - int l_coord = idx2crd(work_tile_info.L_idx, - scheduler_params.problem_tiles_l_); - return make_coord(m_coord, n_coord, _, l_coord); - } - - // Returns whether the block assigned this work should compute the epilogue for the corresponding - // output tile. For the basic tile scheduler, this is always true. - CUTLASS_HOST_DEVICE - static bool - compute_epilogue(WorkTileInfo const&, Params const&) { - return true; - } - - CUTLASS_HOST_DEVICE - static bool - compute_epilogue(WorkTileInfo const&) { - return true; - } - - // Performs the reduction across splits for a given output tile. Since this scheduler does - // not split output tiles, no reduction is needed. - template - CUTLASS_DEVICE - static void - fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} - - // Performs the reduction across splits for a given output tile. No fixup is required for - // work units returned by this scheduler. - template - CUTLASS_DEVICE - void - fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) const { } - - // Returns whether the current WorkTileInfo passed in should continue to be used. Since - // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo - // passed in should not be used after having been processed. - CUTLASS_DEVICE - static bool - continue_current_work(WorkTileInfo&) { - return false; - } - - template - CUTLASS_DEVICE - auto - get_k_tile_iterator(WorkTileInfo const& work_tile_info, ProblemShapeMNKL problem_shape_MNKL, TileShape tile_shape, Shape) { - auto k_tiles = cute::ceil_div(cute::get<2>(problem_shape_MNKL), cute::get<2>(tile_shape)); - return cute::make_coord_iterator(k_tiles); - } - - template - CUTLASS_HOST_DEVICE - static int - get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) { - // All work units returned by this scheduler cover the entire K iteration - // space of the output tile assigned to the work unit. - return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); - } - - CUTLASS_HOST_DEVICE - static uint32_t - get_work_k_tile_start(WorkTileInfo const&) { - // All work units returned by this scheduler start from K tile 0 - return 0u; - } - - CUTLASS_DEVICE - static bool - need_separate_reduction(Params const& params) { - return false; - } - - CUTLASS_DEVICE - bool - is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { - return false; - } - - template - CUTLASS_DEVICE - void - separate_reduction( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - } - - // Shares the accumulator set with peers in the global workspace - template - CUTLASS_DEVICE - static void - share( - Params const& params, - WorkTileInfo const& work_tile_info, - FrgTensorC& accumulators, - uint32_t num_barriers, - uint32_t barrier_idx) { - } - - CUTLASS_DEVICE - static bool - valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { - return true; - } - - CUTLASS_DEVICE - static bool - requires_separate_reduction(Params const& params) { - return false; - } - -public: - // Sink scheduler params as a member - Params scheduler_params; -}; - -} // namespace cutlass::gemm::kernel::detail diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/symm_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/symm_universal.h deleted file mode 100644 index 29cf977c66a46569849e53a48b9cce4a772b96d3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/symm_universal.h +++ /dev/null @@ -1,675 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/complex.h" -#include "cutlass/semaphore.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma1_, ///! Threadblock-scoped triangular matrix multiply-accumulate (A*B or B*A) - typename Mma2_, ///! Threadblock-scoped triangular matrix multiply-accumulate (AT*B or B*AT) - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - SideMode SideMode_, ///! Side Mode for the kernel (kLeft or kRight) - FillMode FillMode_ ///! Fill Mode for triangular matrix (kLower or kUpper) -> -struct SymmUniversal { -public: - - using Mma1 = Mma1_; - using Mma2 = Mma2_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma1::IteratorA::Element; - using ElementB = typename Mma1::IteratorB::Element; - - // Mma1 (TRMM - with diagonal: C_tmp = alpha * A * B) - using LayoutA = typename Mma1::IteratorA::Layout; - using LayoutBT = typename Mma1::IteratorB::Layout; - static ComplexTransform const kMma1TransformA = Mma1::kTransformA; - static ComplexTransform const kMma1TransformB = Mma1::kTransformB; - - // Mma2 (TRMM - withOUT diagonal: alpha * AT * B) - using LayoutB = typename Mma2::IteratorA::Layout; - using LayoutAT = typename Mma2::IteratorB::Layout; - static ComplexTransform const kMma2TransformA = Mma2::kTransformA; - static ComplexTransform const kMma2TransformB = Mma2::kTransformB; - - // Common type definitions for Mma1 and Mma2 - using Operator = typename Mma1::Operator; - using OperatorClass = typename Mma1::Operator::OperatorClass; - using ThreadblockShape = typename Mma1::Shape; - using WarpShape = typename Mma1::Operator::Shape; - using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; - using ArchTag = typename Mma1::ArchTag; - - static int const kStages = Mma1::kStages; - static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; - - // Output related typedefinitions - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static SideMode const kSideModeA = SideMode_; - static FillMode const kFillModeA = FillMode_; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma1::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - - // - // Structures - // - - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmUniversalMode mode = GemmUniversalMode::kGemm; - GemmCoord problem_size{}; - int batch_count{1}; - - typename EpilogueOutputOp::Params epilogue{}; - - void const * ptr_A{nullptr}; - void const * ptr_B{nullptr}; - void const * ptr_C{nullptr}; - void * ptr_D{nullptr}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_C{0}; - int64_t batch_stride_D{0}; - - typename LayoutA::Stride::Index lda{0}; - typename LayoutB::Stride::Index ldb{0}; - typename LayoutC::Stride::Index ldc{0}; - typename LayoutC::Stride::Index ldd{0}; - - // - // Methods - // - - Arguments() = default; - - /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void const * ptr_C, - void * ptr_D, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldd - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(0), - batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { - - } - - /// Returns arguments for the transposed problem sizes - Arguments transposed_problem_size() const { - Arguments args(*this); - - std::swap(args.problem_size.m(), args.problem_size.n()); - - return args; - } - - /// Returns arguments for the transposed matrices - Arguments swapped_matrices() const { - Arguments args(*this); - - std::swap(args.ptr_A, args.ptr_B); - std::swap(args.lda, args.ldb); - std::swap(args.batch_stride_A, args.batch_stride_B); - - return args; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - - cutlass::gemm::GemmCoord problem_size{}; - cutlass::gemm::GemmCoord grid_tiled_shape{}; - int swizzle_log_tile{0}; - - // Mma1 Iterator A and B params - typename Mma1::IteratorA::Params params_A_mma1{}; - typename Mma1::IteratorB::Params params_B_mma1{}; - - // Mma2 Iterator A and B params - typename Mma2::IteratorA::Params params_A_mma2{}; - typename Mma2::IteratorB::Params params_B_mma2{}; - - typename Epilogue::OutputTileIterator::Params params_C{}; - typename Epilogue::OutputTileIterator::Params params_D{}; - - typename EpilogueOutputOp::Params output_op{}; - - GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; - int batch_count {0}; - int gemm_k_size {0}; - - void * ptr_A{nullptr}; - void * ptr_B{nullptr}; - void * ptr_C{nullptr}; - void * ptr_D{nullptr}; - - int64_t batch_stride_A {0}; - int64_t batch_stride_B {0}; - int64_t batch_stride_C {0}; - int64_t batch_stride_D {0}; - - int *semaphore{nullptr}; - - // - // Methods - // - Params() = default; - - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A_mma1(args.lda), - params_B_mma1(args.ldb), - params_A_mma2(args.lda), - params_B_mma2(args.ldb), - params_C(args.ldc), - params_D(args.ldd), - output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_D(const_cast(args.ptr_D)), - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), - semaphore(static_cast(workspace)) { - } - - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { - - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); - ptr_D = args.ptr_D; - - output_op = args.epilogue; - - semaphore = static_cast(workspace); - } - - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma1::SharedStorage mma1_main_loop; - typename Mma2::SharedStorage mma2_main_loop; - typename Epilogue::SharedStorage epilogue; - }; - -public: - - // - // Methods - // - - CUTLASS_DEVICE - SymmUniversal() { } - - /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size) { - - static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { - - return Status::kErrorMisalignedOperand; - } - - return Status::kSuccess; - } - - static Status can_implement(Arguments const &args) { - return can_implement(args.problem_size); - } - - /// Executes two GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA *ptr_A = static_cast(params.ptr_A); - ElementB *ptr_B = static_cast(params.ptr_B); - - // - // Fetch pointers based on mode. - // - if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - - __syncthreads(); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_MxK_mma1{ - threadblock_tile_offset.m() * Mma1::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_KxN_mma1{ - offset_k, - threadblock_tile_offset.n() * Mma1::Shape::kN - }; - - cutlass::MatrixCoord tb_offset_MxK_mma2{ - threadblock_tile_offset.m() * Mma1::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_KxN_mma2{ - offset_k, - threadblock_tile_offset.n() * Mma1::Shape::kN - }; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply for Mma1 - Mma1 mma1(shared_storage.mma1_main_loop, thread_idx, warp_idx, lane_idx); - - // Construct thread-scoped matrix multiply for Mma2 - Mma2 mma2(shared_storage.mma2_main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma1::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; - int gemm_k_iterations_mma1 = gemm_k_iterations; - int gemm_k_iterations_mma2 = gemm_k_iterations; - - - /****************************************************************************************************** - * SYMM (Side Mode, Fill Mode) is made of two TRMMs: - First TRMM (Mma1: Side Mode, Fill Mode, Non-Unit Diag): (A * B) or (B * A) - Second TRMM (Mma2: Side Mode, Inverted Fill Mode, Unit Diag): (AT * B) or (B * AT) - - * For the first TRMM (Mma1) of SYMM, the following method is used to calculate the k-iterations: - First two cases: (Left Side, Lower Fill) and (Right Side, Upper Fill) are transpose of each other - - (Left Side, Lower Fill): calculate bottom of the CTA tile, then find the k-iterations - needed to process all elements till that coordinate. - - (Right Side, Upper Fill): calculate right end of the CTA tile, then find the k-iterations - needed to process all elements till that coordinate. - - Last two cases: (Left Side, Upper Fill) and (Right Side, Lower Fill) are transpose of each other - - (Left Side, Upper Fill): calculate the top of the CTA tile, then find k-iterations - that can be skipped for all elements of this tile. - - (Right Side, Lower Fill): calculate the left start of the CTA tile, then find k-iterations - that can be skipped for all elements of this tile. - - * For the second TRMM (Mma2) of SYMM, the k-iterations and threadblock offsets are calculated - the same way as the first TRMM (Mma1) of same side mode but with inverted fill mode. - For example, if the first TRMM is left sided with lower fill, the second TRMM would be - left sided with upper fill. - ********************************************************************************************************/ - - if (kSideModeA == SideMode::kLeft && kFillModeA == FillMode::kLower) { - - int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM + Mma1::Shape::kK - 1) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma1 < gemm_k_iterations) { - gemm_k_iterations_mma1 = k_iterations_till_diagonal_mma1; - } - - int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.m()) * Mma1::Shape::kM) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma2 != 0) { - tb_offset_MxK_mma2 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma2 * Mma1::Shape::kK}); - tb_offset_KxN_mma2 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma2 * Mma1::Shape::kK, 0}); - gemm_k_iterations_mma2 -= k_iterations_till_diagonal_mma2; - } - - } else if (kSideModeA == SideMode::kRight && kFillModeA == FillMode::kUpper) { - - int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.n() + 1) * Mma1::Shape::kN + Mma1::Shape::kK - 1) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma1 < gemm_k_iterations) { - gemm_k_iterations_mma1 = k_iterations_till_diagonal_mma1; - } - - int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.n()) * Mma1::Shape::kN) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma2 != 0) { - tb_offset_MxK_mma2 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma2 * Mma1::Shape::kK}); - tb_offset_KxN_mma2 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma2 * Mma1::Shape::kK, 0}); - gemm_k_iterations_mma2 -= k_iterations_till_diagonal_mma2; - } - - } else if (kSideModeA == SideMode::kLeft && kFillModeA == FillMode::kUpper) { - - int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.m()) * Mma1::Shape::kM) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma1 != 0) { - tb_offset_MxK_mma1 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma1 * Mma1::Shape::kK}); - tb_offset_KxN_mma1 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma1 * Mma1::Shape::kK, 0}); - gemm_k_iterations_mma1 -= k_iterations_till_diagonal_mma1; - } - - int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM + Mma1::Shape::kK - 1) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma2 < gemm_k_iterations) { - gemm_k_iterations_mma2 = k_iterations_till_diagonal_mma2; - } - - } else if (kSideModeA == SideMode::kRight && kFillModeA == FillMode::kLower) { - - int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.n()) * Mma1::Shape::kN) / Mma1::Shape::kK; - - if (k_iterations_till_diagonal_mma1 != 0) { - tb_offset_MxK_mma1 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma1 * Mma1::Shape::kK}); - tb_offset_KxN_mma1 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma1 * Mma1::Shape::kK, 0}); - gemm_k_iterations_mma1 -= k_iterations_till_diagonal_mma1; - } - - int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.n() + 1) * Mma1::Shape::kN + Mma1::Shape::kK - 1) / Mma1::Shape::kK; - if (k_iterations_till_diagonal_mma2 < gemm_k_iterations) { - gemm_k_iterations_mma2 = k_iterations_till_diagonal_mma2; - } - - } - - // Construct iterators to A and B operands for Mma1 - typename Mma1::IteratorA iterator_A_mma1( - params.params_A_mma1, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_MxK_mma1); - - typename Mma1::IteratorB iterator_B_mma1( - params.params_B_mma1, - ptr_B, - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_KxN_mma1); - - // Construct iterators to A and B operands for Mma2 - typename Mma2::IteratorA iterator_A_mma2( - params.params_A_mma2, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_MxK_mma2); - - typename Mma2::IteratorB iterator_B_mma2( - params.params_B_mma2, - ptr_B, - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_KxN_mma2); - - // Compute threadblock-scoped matrix multiply-add (A x B) or (B x A) - mma1( - gemm_k_iterations_mma1, - accumulators, - iterator_A_mma1, - iterator_B_mma1, - accumulators); - - // Compute threadblock-scoped matrix multiply-add (AT x B) or (B x AT) - mma2( - gemm_k_iterations_mma2, - accumulators, - iterator_A_mma2, - iterator_B_mma2, - accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - //assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma1::Shape::kM, - threadblock_tile_offset.n() * Mma1::Shape::kN - ); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - ElementC *ptr_C = static_cast(params.ptr_C); - ElementC *ptr_D = static_cast(params.ptr_D); - - // - // Fetch pointers based on mode. - // - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - if (params.mode == GemmUniversalMode::kGemm) { - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - } - else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kBatched) { - ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; - ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - ptr_C, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - - __threadfence(); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue( - output_op, - iterator_D, - accumulators, - iterator_C); - - // - // Release the semaphore - // - - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp deleted file mode 100644 index d78bc4b056c61e9cc27f6e17a578d631b62aeb4e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ /dev/null @@ -1,423 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -/*! \file - \brief Utilities for selecting default tile schedulers -*/ - -#include "cutlass/arch/arch.h" -#include "cutlass/detail/dependent_false.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm { - -// -// Tags for specifying tile schedulers -// - -struct PersistentScheduler { }; - -struct StreamKScheduler { }; - -struct GroupScheduler { }; // Only used for Grouped GEMMs - -struct DynamicPersistentScheduler { }; - -struct StaticPersistentScheduler { }; - -} // namespace cutlass::gemm -//////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp" - -#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" -#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp" -#include "cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp" -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel::detail { - -// -// Selectors mapping tile scheduler tag and arch tag to a tile scheduler class -// - -template < - class TileSchedulerTag, - class ArchTag, - class TileShape, - class ClusterShape - , uint32_t SchedulerPipelineStageCount = 2 - , class ProblemShapeType = void -> -struct TileSchedulerSelector { - static_assert(cutlass::detail::dependent_false, - "Could not select a tile scheduler for given parameters."); -}; - -template < - class ArchTag, - class TileShape, - class ClusterShape - , uint32_t SchedulerPipelineStageCount -> -struct TileSchedulerSelector< - PersistentScheduler, - ArchTag, - TileShape, - ClusterShape - , SchedulerPipelineStageCount - > { - using Scheduler = PersistentTileSchedulerSm90; -}; - -// Default (void) for Sm90 maps to PersistentTileSchedulerSm90 -template < - class ArchTag, - class TileShape, - class ClusterShape - , uint32_t SchedulerPipelineStageCount -> -struct TileSchedulerSelector< - void, - ArchTag, - TileShape, - ClusterShape - , SchedulerPipelineStageCount - > { - using Scheduler = typename TileSchedulerSelector< - PersistentScheduler, - ArchTag, - TileShape, - ClusterShape - , SchedulerPipelineStageCount - >::Scheduler; -}; - -template < - class TileShape, - class ClusterShape - , uint32_t SchedulerPipelineStageCount -> -struct TileSchedulerSelector< - StreamKScheduler, - arch::Sm90, - TileShape, - ClusterShape - , SchedulerPipelineStageCount - > { - using Scheduler = PersistentTileSchedulerSm90StreamK; -}; - -template < - class ArchTag, - class TileShape, - class ClusterShape, - uint32_t SchedulerPipelineStageCount -> -struct TileSchedulerSelector< - StaticPersistentScheduler, - ArchTag, - TileShape, - ClusterShape - , SchedulerPipelineStageCount - > { - using Scheduler = PersistentTileSchedulerSm90; -}; - -template < - class TileShape, - class ClusterShape, - uint32_t SchedulerPipelineStageCount, - class GroupProblemShape -> -struct TileSchedulerSelector< - GroupScheduler, - arch::Sm90, - TileShape, - ClusterShape - , SchedulerPipelineStageCount - , GroupProblemShape - > { - using Scheduler = PersistentTileSchedulerSm90Group; -}; - -template -struct TileSchedulerSelector< - PersistentScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// Ptr-Array kernel may provide a specialized ArrayProblemShape type -template -struct TileSchedulerSelector< - PersistentScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount, - ProblemShape> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// Default (void) for Sm100 maps to PersistentTileSchedulerSm100 -template -struct TileSchedulerSelector< - void, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount - >; -}; - -// Default (void) for Sm100 maps to PersistentTileSchedulerSm100 -// Ptr-Array kernel may provide a specialized ArrayProblemShape type -template -struct TileSchedulerSelector< - void, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount, - ProblemShape> { - using Scheduler = typename TileSchedulerSelector< - PersistentScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount>::Scheduler; -}; - -// SM100 Group tile scheduler -template < - class TileShape, - class ClusterShape, - uint32_t SchedulerPipelineStageCount, - class GroupProblemShape -> -struct TileSchedulerSelector< - GroupScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount, - GroupProblemShape - > { - using Scheduler = PersistentTileSchedulerSm100Group; -}; - -// SM100 stream-K scheduler -template -struct TileSchedulerSelector< - StreamKScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100StreamK< - TileShape, - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// SM100 dynamic tile scheduler -template -struct TileSchedulerSelector< - DynamicPersistentScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount>; -}; - -template < - class TileShape, - class ClusterShape, - uint32_t SchedulerPipelineStageCount -> -struct TileSchedulerSelector< - StaticPersistentScheduler, - arch::Sm100, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = StaticPersistentTileScheduler100; -}; - -template -struct TileSchedulerSelector< - PersistentScheduler, - arch::Sm103, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// Ptr-Array kernel may provide a specialized ArrayProblemShape type -template -struct TileSchedulerSelector< - PersistentScheduler, - arch::Sm103, - TileShape, - ClusterShape, - SchedulerPipelineStageCount, - ProblemShape> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// SM103 Group tile scheduler -template < - class TileShape, - class ClusterShape, - uint32_t SchedulerPipelineStageCount, - class GroupProblemShape -> -struct TileSchedulerSelector< - GroupScheduler, - arch::Sm103, - TileShape, - ClusterShape, - SchedulerPipelineStageCount, - GroupProblemShape - > { - using Scheduler = PersistentTileSchedulerSm100Group; -}; - -template -struct TileSchedulerSelector< - StreamKScheduler, - arch::Sm103, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100StreamK< - TileShape, - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// Default (void) for Sm120 maps to PersistentTileSchedulerSm100 -template -struct TileSchedulerSelector< - void, - arch::Sm120, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100< - ClusterShape, - SchedulerPipelineStageCount - >; -}; - -// PersistentScheduler for Sm120 maps to PersistentTileSchedulerSm100 -template -struct TileSchedulerSelector< - PersistentScheduler, - arch::Sm120, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100; -}; - - -// StreamKScheduler for Sm120 maps to PersistentTileSchedulerSm100StreamK -template -struct TileSchedulerSelector< - StreamKScheduler, - arch::Sm120, - TileShape, - ClusterShape, - SchedulerPipelineStageCount> { - using Scheduler = PersistentTileSchedulerSm100StreamK< - TileShape, - ClusterShape, - SchedulerPipelineStageCount>; -}; - -// SM120 Group tile scheduler -template < - class TileShape, - class ClusterShape, - uint32_t SchedulerPipelineStageCount, - class GroupProblemShape -> -struct TileSchedulerSelector< - GroupScheduler, - arch::Sm120, - TileShape, - ClusterShape, - SchedulerPipelineStageCount, - GroupProblemShape - > { - using Scheduler = PersistentTileSchedulerSm90Group; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel::detail - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp deleted file mode 100644 index b1d192c13a45dff4c0082ab8610e6b94dca13996..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler_detail.hpp +++ /dev/null @@ -1,88 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -namespace cutlass::gemm::kernel::detail { - -//////////////////////////////////////////////////////////////////////////////// - -enum class RasterOrder { - AlongM, - AlongN -}; - -enum class RasterOrderOptions { - Heuristic, - AlongM, - AlongN -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Strategies for computing reductions between CTAs computing portions of a given output tile -enum class ReductionMode { - // Participating CTAs perform reduction in a turnstile fashion in order of the K extent - // covered by each CTA. This requires a lock to be held exclusively by the CTA that is - // currently accumulating. - // - // Turnstile accumulation ensures deterministic numeric behavior when using this mode. - Deterministic, - - // Participating CTAs perform reduction atomically to the same workspace (mostly) without locking. - // Locks are used only to wait for the first CTA to write its partial values (to initialize the - // workspace), and for all but the final CTA to have accumulated (so that the final CTA can load - // the accumulated value and accumulate it into registers on top of which the epilogue will - // be performed). - // - // Due to the nondeterminsitic ordering of accumulation, deterministic numeric behavior cannot - // be guaranteed with this mode (e.g., floating-point rounding error will depend on the order - // of accumulation) - Nondeterministic -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Strategies for decomposing the problem -enum class DecompositionMode { - // Use a heuristic to determine whether data-parallel, split-K, or stream-K decomposition should be performed - Heuristic, - // Force a data-parallel decomposition - DataParallel, - // Force a split-K decomposition. This should be paired with setting the `splits` parameter - SplitK, - // Force a stream-K decomposition - StreamK -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel::detail diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h deleted file mode 100644 index 96037b121470b8d0c841dd876f1c4802ba1afd52..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ /dev/null @@ -1,2609 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -/*! \file - \brief Parameters structures for persistent tile schedulers -*/ - -#include "cutlass/coord.h" -#include "cutlass/kernel_hardware_info.h" -#include "cutlass/workspace.h" -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm_coord.h" -#include "cutlass/gemm/kernel/tile_scheduler_detail.hpp" -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { -namespace detail { - -//////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -static uint32_t -get_max_cta_occupancy(int max_sm_per_gpc, GemmCoord cluster_shape, int sm_count) { - // Provided SM count could possibly be less than the assumed maximum SMs per GPC - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); - int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; - // Suppose max_sm_per_gpc = 20, cluster_size = 8, sm_count = 148 - // min_num_gpc = 148 / 20 = 7 - // max_cta_occupancy_per_gpc = 20 - (20 % 8) = 16 - // cta_per_device = 7 * 16 = 112 - // num_gpc_residual = 148 % 20 = 8 - // max_cta_occupancy_per_residual_gpc = 8 - (8 % 8) = 8 - // cta_per_device += 8 = 120 - // cta_per_device = 120 < 148 ? 148 : 120 = 148 - - // The calculation below allows for larger grid size launch for different GPUs. - int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; - int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); - cta_per_device += max_cta_occupancy_per_residual_gpc; - - cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; - return cta_per_device; -} - -//////////////////////////////////////////////////////////////////////////////// - -// -// Parameters for SM90 tile schedulers -// - -// Parameters for SM90 persistent tile scheduler -struct PersistentTileSchedulerSm90Params { - using RasterOrder = cutlass::gemm::kernel::detail::RasterOrder; - using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; - - FastDivmodU64Pow2 divmod_cluster_shape_major_{}; - FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; - FastDivmodU64 divmod_batch_{}; - FastDivmodU64 divmod_cluster_blk_major_{}; - - uint64_t blocks_per_problem_ = 0; - int32_t log_swizzle_size_ = 0; - RasterOrder raster_order_ = RasterOrder::AlongN; - - uint32_t problem_tiles_m_ = 0; - uint32_t problem_tiles_n_ = 0; - uint32_t problem_tiles_l_ = 0; - uint32_t cluster_shape_m_ = 0; - uint32_t cluster_shape_n_ = 0; - - // Initializes members. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - void - initialize( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - return initialize( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle_size, - raster_order_option - ); - } - - // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - void - initialize( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - - CUTLASS_UNUSED(hw_info); - - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); - - problem_tiles_m_ = problem_blocks_m / cluster_shape.m(); - problem_tiles_n_ = problem_blocks_n / cluster_shape.n(); - problem_tiles_l_ = problem_blocks.z; - cluster_shape_m_ = cluster_shape.m(); - cluster_shape_n_ = cluster_shape.n(); - - RasterOrder raster_order = get_rasterization_order( - problem_blocks_m, - problem_blocks_n, - raster_order_option - ); - - // - // Set members - // - - blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z; - log_swizzle_size_ = log_swizzle_size; - raster_order_ = raster_order; - divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n); - - if (raster_order == RasterOrder::AlongN) { - divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n()); - divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.m()); - divmod_cluster_blk_major_ = FastDivmodU64(problem_blocks_n / cluster_shape.n()); - } - else { - divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m()); - divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n()); - divmod_cluster_blk_major_ = FastDivmodU64(problem_blocks_m / cluster_shape.m()); - } - } - - // Given the inputs, computes the physical grid we should launch. - // This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - BatchedGemmCoord problem_shape, - GemmCoord cta_shape, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - bool truncate_by_problem_size=true, - bool bypass_sm90_occupancy_calculation=false - ) { - - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); - return get_grid_shape( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle_size, - raster_order_option, - truncate_by_problem_size, - bypass_sm90_occupancy_calculation - ); - } - - // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - bool truncate_by_problem_size=true, - bool bypass_sm90_occupancy_calculation=false - ) { - - int const sm_count = hw_info.sm_count; - int const max_active_clusters = hw_info.max_active_clusters; - - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); - - int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; - - RasterOrder raster_order = get_rasterization_order( - problem_blocks_m, - problem_blocks_n, - raster_order_option - ); - - dim3 launch_grid; - - if (raster_order == RasterOrder::AlongN) { - launch_grid = dim3(cluster_shape.m(), 1, 1); - } - else { - launch_grid = dim3(1, cluster_shape.n(), 1); - } - - auto possibly_truncate = [&](int x, int y) { - if (truncate_by_problem_size) { - return platform::min(x, y); - } - else { - return x; - } - }; - - // The else path is generic, however, we can avoid some divs if we know cluster size is 1 - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - if (cluster_size == 1) { - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate(sm_count, problem_blocks_total); - } - else { - launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); - } - } - // In case the maximum number of clusters that could co-exist on the target device is - // already calculated using cudaOccupancyMaxActiveClusters - else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) { - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate( - max_active_clusters * cluster_shape.n(), - problem_blocks_total / cluster_shape.m()); - - } - else { - launch_grid.x = possibly_truncate( - max_active_clusters * cluster_shape.m(), - problem_blocks_total / cluster_shape.n()); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using cudaOccupancyMaxActiveClusters = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - else { - int cta_per_device = sm_count; - if (!bypass_sm90_occupancy_calculation) { - /* - * Optimal grid size calculation is based on - * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU - * Hence, maximum SMs per GPC = 18 - */ - constexpr int max_sm_per_gpc = 18; - cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); - } - - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate( - cta_per_device / cluster_shape.m(), - problem_blocks_total / cluster_shape.m()); - } - else { - launch_grid.x = possibly_truncate( - cta_per_device / cluster_shape.n(), - problem_blocks_total / cluster_shape.n()); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using heuristics = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - return launch_grid; - } - - CUTLASS_HOST_DEVICE - static int32_t - get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { - int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n); - if (max_swizzle_size >= 8 && min_cta_dim >= 6) { - return 3; - } - else if (max_swizzle_size >= 4 && min_cta_dim >= 3) { - return 2; - } - else if (max_swizzle_size >= 2 && min_cta_dim >= 2) { - return 1; - } - else { - return 0; - } - } - - CUTLASS_HOST_DEVICE - static RasterOrder - get_rasterization_order( - uint32_t tiles_m, - uint32_t tiles_n, - RasterOrderOptions raster_order_option - ) { - - if (raster_order_option == RasterOrderOptions::Heuristic) { - if (tiles_n > tiles_m) { - return RasterOrder::AlongM; - } - else { - return RasterOrder::AlongN; - } - } - else { - switch (raster_order_option) { - case RasterOrderOptions::AlongN: - return RasterOrder::AlongN; - break; - default: - return RasterOrder::AlongM; - } - } - } - - // Get the number of CTA tiles in this problem. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - CUTLASS_HOST_DEVICE - static dim3 - get_tiled_cta_shape_mnl(BatchedGemmCoord problem_shape, GemmCoord cta_shape, GemmCoord cluster_shape) { - auto cta_m = (problem_shape.m() + cta_shape.m() - 1) / cta_shape.m(); - auto cta_n = (problem_shape.n() + cta_shape.n() - 1) / cta_shape.n(); - - return get_tiled_cta_shape_mnl(problem_shape, cluster_shape, cta_m, cta_n); - } - - // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE - static dim3 - get_tiled_cta_shape_mnl(BatchedGemmCoord problem_shape, GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) { - - // Round up to nearest multiple of cluster dim along each mode - auto problem_blocks_m = ((cta_m + cluster_shape.m() - 1) / cluster_shape.m()) * cluster_shape.m(); - auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n(); - - return { - static_cast(problem_blocks_m), - static_cast(problem_blocks_n), - static_cast(problem_shape.batch()) - }; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Parameters for SM90 persistent stream-K scheduler -struct PersistentTileSchedulerSm90StreamKParams { - using ReductionMode = cutlass::gemm::kernel::detail::ReductionMode; - using DecompositionMode = cutlass::gemm::kernel::detail::DecompositionMode; - - - using UnderlyingParams = PersistentTileSchedulerSm90Params; - using RasterOrder = cutlass::gemm::kernel::detail::RasterOrder; - using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; - - // Cluster dimensions are typically always a power of 2, so use - // the power-of-two variants of FastDivmod for these. - FastDivmodU64Pow2 divmod_cluster_shape_major_{}; - FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; - - FastDivmodU64 divmod_batch_{}; - FastDivmodU64 divmod_cluster_blk_major_{}; - - // Total number of cluster-sized output tiles (i.e., not including any - // splitting factors). This is primarily used for split-K decompositions, - // and may be overridden in other decompositions. - FastDivmodU64 divmod_clusters_mnl_{}; - - // We divide up the number of stream-K tiles amongst G groups of stream-K units. - // The stream-K units within a group collaborate to compute over the `sk_tiles / G` - // tiles assigned to that group. Non-unit group sizes can help to preserve L2 locality of - // partial chunks computed by stream-K units -- units 0 in each group will compute identical K extents - // of tiles that would be assigned in the same wave according to the rasterization order of the - // data-parallel formulation of the problem. - FastDivmodU64 divmod_sk_groups_{}; - - // Number of stream-K units in each group - FastDivmodU64 divmod_sk_units_per_group_{}; - - uint64_t units_per_problem_ = 0; - FastDivmod divmod_tiles_per_output_tile_{}; - int32_t log_swizzle_size_ = 0; - RasterOrder raster_order_ = RasterOrder::AlongN; - - // The splitting factor to be used in a split-K decomposition of the problem. - // If this is set to a value greater than 1, stream-K decomposition logic - // is bypassed in favor of a split-K decomposition. - FastDivmod divmod_splits_{}; - - // Number of stream-K or split-K work units that compute an extra k iteration. - // This is done to handle residuals in dividing up the k iteration space. - // For stream-K, since the actual assignment of work to stream-K units will be done - // at the granularity of a cluster, we store only the number of big clusters. - uint32_t big_units_ = 0; - - // The number of groups of stream-K units that will process an extra stream-K tile cluster. - uint32_t big_groups_ = 0; - - // Workspace for holding partial accumulators to be reduced across stream-K/split-K units - void* reduction_workspace_ = nullptr; - - // Number of tiles covered by stream-K work units - uint32_t sk_tiles_ = 0; - - // Number of work units computing stream-K tiles - uint32_t sk_units_ = 0; - - // Number of tiled k iterations computed by each stream-K work unit. This - // can potentially cover more than one output tile. - FastDivmod divmod_k_tiles_per_sk_unit_{}; - // Number of tiled k iterations computed by each "big" stream-K units, which - // processes one more K chunk than a "normal" stream-K unit. - FastDivmod divmod_k_tiles_per_sk_big_unit_{}; - - // Strategy to use when reducing between collaborating CTAs - ReductionMode reduction_mode_ = ReductionMode::Deterministic; - - // The number of sub blocks in the kernel epilogue - FastDivmodU64 divmod_epilogue_subtile_{}; - - // The number of blocks that launched for doing separate reduction - uint32_t separate_reduction_units_ = 0; - - // Minimum number of k tiles that can be assigned to a stream-K unit - static constexpr uint32_t min_iters_per_sk_unit_ = 8u; - - // Maximum number of groups of stream-K units - static constexpr uint32_t max_sk_groups_ = 8u; - - // ktile start from even for each cta - uint32_t ktile_start_alignment_count_ { 1u }; - - // Divides dividend by the cluster size - CUTLASS_HOST_DEVICE - uint64_t - div_cluster_size(uint64_t dividend) const { - // Use each underlying fast divmod rather than performing integer division - // by the multiplication of major.divisor * minor.divisor - return divmod_cluster_shape_minor_.divide( - divmod_cluster_shape_major_.divide(dividend) - ); - } - - - // Divides dividend by the cluster size in the M dimension - CUTLASS_HOST_DEVICE - uint64_t - truncate_to_cluster_size_m(uint64_t dividend) const { - if (raster_order_ == RasterOrder::AlongN) { - return divmod_cluster_shape_minor_.divide(dividend) * divmod_cluster_shape_minor_.divisor; - } - else { - return divmod_cluster_shape_major_.divide(dividend) * divmod_cluster_shape_major_.divisor; - } - } - - // Divides dividend by the cluster size in the N dimension - CUTLASS_HOST_DEVICE - uint64_t - truncate_to_cluster_size_n(uint64_t dividend) const { - if (raster_order_ == RasterOrder::AlongM) { - return divmod_cluster_shape_minor_.divide(dividend) * divmod_cluster_shape_minor_.divisor; - } - else { - return divmod_cluster_shape_major_.divide(dividend) * divmod_cluster_shape_major_.divisor; - } - } - - - CUTLASS_HOST_DEVICE - uint64_t - get_cluster_size() const { - return divmod_cluster_shape_minor_.divisor * divmod_cluster_shape_major_.divisor; - } - - // Returns whether the kernel uses separate reduction - CUTLASS_HOST_DEVICE - bool - requires_separate_reduction() const { - return separate_reduction_units_ > 0; - } - - // Returns the maximum number of peers that can collaborate on a given output tile - CUTLASS_HOST_DEVICE - static uint32_t - max_peers_per_tile(uint64_t sk_units, uint64_t sk_tiles) { - // When we can divide up our SK units to SK tiles evenly, the number of peers - // per SK tile is exactly (sk_units_ / sk_tiles_). In cases where this division - // is not exact, some tiles will need to be covered by additional SK units. Because - // the extra work can occur at both the beginning and the end of the SK tile, at - // most 2 extra peers will be needed. - return static_cast(sk_units / sk_tiles + 2); - } - - // Initializes members. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - void - initialize( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - ReductionMode reduction_mode, - DecompositionMode decomposition_mode, - void* workspace, - const uint32_t epilogue_subtile = 1u, - uint32_t ktile_start_alignment_count = 1u, - bool bypass_sm90_occupancy_calculation=false - ) { - dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl( - problem_shape, tile_shape, cluster_shape); - - // Number of k tiles in each output tile - uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); - - initialize( - problem_blocks, - k_tiles_per_output_tile, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - reduction_mode, - decomposition_mode, - workspace, - epilogue_subtile, - ktile_start_alignment_count, - bypass_sm90_occupancy_calculation - ); - } - - // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - void - initialize( - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - ReductionMode reduction_mode, - DecompositionMode decomposition_mode, - void* workspace, - const uint32_t epilogue_subtile = 1, - uint32_t ktile_start_alignment_count = 1u, - bool bypass_sm90_occupancy_calculation=false - ) { - - #if !defined(__CUDACC_RTC__) - if (hw_info.sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - } - #endif // !defined(__CUDACC_RTC__) - - ktile_start_alignment_count_ = ktile_start_alignment_count; - UnderlyingParams underlying_params; - underlying_params.initialize( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle, - raster_order_option - ); - - // Set basic parameters that not affected by any heuristics in advance. - set_params_base(underlying_params, workspace); - - // Call for internal streamk heuristic to setup streamk related params - stream_k_heuristic( - underlying_params, - problem_blocks, - k_tiles_per_output_tile, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - epilogue_subtile, - ktile_start_alignment_count, - bypass_sm90_occupancy_calculation - ); - } - - // max_sk_groups_ unless this extends beyond the extent of the dimension over - // which the problem is rasterized. For example, if the tiled problem shape - // (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16, - // and we rasterize along the M dimension, we choose 4 groups, rather than 8. - // If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M). - uint32_t calculate_groups( - UnderlyingParams underlying_params, - ReductionMode reduction_mode, - uint32_t problem_blocks_m, - uint32_t problem_blocks_n, - GemmCoord cluster_shape, - uint64_t cluster_size, - uint32_t sk_tiles, - uint64_t sk_cluster_tiles, - uint64_t sk_units, - uint32_t k_tiles_per_output_tile, - bool do_separate_reduction) { - - uint32_t max_groups_problem; - if (underlying_params.raster_order_ == RasterOrder::AlongM) { - max_groups_problem = problem_blocks_m / cluster_shape.m(); - } - else { - max_groups_problem = problem_blocks_n / cluster_shape.n(); - } - - // Select the number of groups that will be use. We start with the maximum - // number of potential groups, and iterate down looking for a group size that - // evenly divides the stream-K units and tiles, and for which the resulting - // number of K tiles per stream-K unit remains above min_iters_per_sk_unit_ - - uint32_t groups = platform::min(max_groups_problem, uint32_t(max_sk_groups_)); - // Grouping is disabled when separate reduction is used because grouping is primarily an attempt - // to improve L2 locality, and L2-locality optimizations are unnecessary when the the kernel - // is a single wave (which is the case for separate reduction). - if ( - do_separate_reduction - ) { - groups = 1; - } - - uint32_t fallback_groups = 0; - auto sk_cluster_units = sk_units / cluster_size; - - auto sk_splits_too_small = [&](uint32_t g) { - // Check whether the number of K tiles computed per stream-K unit is less - // than min_iters_per_sk_unit_ - auto total_sk_cluster_tiles = (sk_cluster_tiles / g) * cluster_size; - auto total_sk_k_tiles = total_sk_cluster_tiles * k_tiles_per_output_tile; - auto k_tiles_per_sk_unit = total_sk_k_tiles / (sk_units / g); - return k_tiles_per_sk_unit < min_iters_per_sk_unit_; - }; - - auto is_ideal_grouping = [&](uint32_t g) { - // An ideal grouping will evenly divide stream-K clusters, evenly divide - // stream-K tiles, and not result in stream-K splits that are too small. - return (sk_cluster_units % g == 0) && (sk_cluster_tiles % g == 0) && !sk_splits_too_small(g); - }; - - auto is_valid_grouping = [&](uint32_t g) { - // A grouping is valid, but not ideal, if it evenly divides the - // stream-K clusters and does not result in stream-K splits that are - // too small. Such a setting can be used as a fallback option in the - // case that an ideal grouping is not achievable - return sk_cluster_units % g == 0 && !sk_splits_too_small(g); - }; - - while (groups > 1 && !is_ideal_grouping(groups)) { - if (fallback_groups == 0 && is_valid_grouping(groups)) { - // Set fallback groups once in preference for a larger number of groups. - fallback_groups = groups; - } - --groups; - } - - // If groups == 1, we did not find a group count that satisfies all criteria. If we have - // found a fallback group count, use this instead. - if (groups == 1 && fallback_groups > 0) { - groups = fallback_groups; - } - return groups; - } - - // Stream-K kernel use below function to set stream-K feature related parameters to choose - // optimal/customized decomposition mode. - void stream_k_heuristic( - UnderlyingParams underlying_params, - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - const uint32_t epilogue_subtile = 1, - uint32_t ktile_start_alignment_count = 1u, - bool bypass_sm90_occupancy_calculation=false) { - uint32_t groups = 0; - uint32_t sk_tiles = 0; - uint64_t sk_units = 0; - uint64_t cluster_size = 0; - uint64_t dp_units = 0; - uint64_t k_tiles_per_group = 0; - uint64_t k_tiles_per_sk_unit = 0; - uint64_t sk_big_groups = 0; - uint32_t sk_splits = 1; - // Self calculated optimal heuristic mode - DecompositionMode heuristic_mode = - select_decomposition_mode( - groups, - sk_tiles, - sk_units, - cluster_size, - dp_units, - k_tiles_per_group, - k_tiles_per_sk_unit, - sk_big_groups, - sk_splits, - underlying_params, - problem_blocks, - k_tiles_per_output_tile, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - epilogue_subtile, - ktile_start_alignment_count, - bypass_sm90_occupancy_calculation - ); - - // Given heuristic_mode returned from the heuristic() method, set params fields. - // Here, we decouple the params that have no relation with - // decomposition mode from the params that are decided within heuristic(). - set_params( - heuristic_mode, - groups, - sk_tiles, - sk_units, - cluster_size, - dp_units, - k_tiles_per_group, - k_tiles_per_sk_unit, - sk_big_groups, - sk_splits, - underlying_params, - problem_blocks, - k_tiles_per_output_tile, - cluster_shape, - splits, - epilogue_subtile, - reduction_mode, - ktile_start_alignment_count - ); - } - - // Return the optimal decomposition result by heuristic. - DecompositionMode select_decomposition_mode( - uint32_t &groups, - uint32_t &sk_tiles, - uint64_t &sk_units, - uint64_t &cluster_size, - uint64_t &dp_units, - uint64_t &k_tiles_per_group, - uint64_t &k_tiles_per_sk_unit, - uint64_t &sk_big_groups, - uint32_t &sk_splits, - UnderlyingParams underlying_params, - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t epilogue_subtile, - uint32_t ktile_start_alignment_count, - bool bypass_sm90_occupancy_calculation=false - ) { - - // Get block numbers in m, n and l dimensions - if (decomposition_mode == DecompositionMode::SplitK || - (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { - // Short circuit to basic split-K decomposition - uint32_t adapted_splits = adjust_split_count( - splits, hw_info.sm_count, k_tiles_per_output_tile - , ktile_start_alignment_count - ); - sk_splits = adapted_splits; - return DecompositionMode::SplitK; - } - else { - // Calculate the maximum number of blocks from clusters of shape cluster_shape that we - // can fit within sm_count SMs. - // Get block numbers in m, n and l dimensions - auto problem_blocks_l = problem_blocks.z; - auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); - uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; - dim3 grid = get_grid_shape( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle, - raster_order_option, - bypass_sm90_occupancy_calculation - ); - uint64_t ctas_per_wave = grid.x * grid.y; - cluster_size = cluster_shape.m() * cluster_shape.n(); - uint64_t ctas_per_wave_in_full_clusters = (ctas_per_wave / cluster_size) * cluster_size; - - // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. - sk_tiles = get_num_sk_tiles( - output_tiles, - ctas_per_wave, - cluster_size, - k_tiles_per_output_tile, - decomposition_mode, - ctas_per_wave_in_full_clusters - ); - uint64_t dp_tiles = output_tiles - sk_tiles; - // Calculate the number of work units covering the data-parallel and stream-K tiles. - // A "work unit" is a single index in the linearized ID space used by the scheduler. - // We distinguish it from a "block," which is typically tied to a hardware unit - // (e.g., the callers into this scheduler will be persistent thread blocks). - // A work unit can encompass multiple output tiles worth of work (as will be the - // case for stream-K blocks). - // Since splitting is not required for data-parallel tiles, only one data-parallel unit - // is needed per data-parallel tile. - dp_units = dp_tiles; - - uint64_t ctas_per_sk_wave = ctas_per_wave; - ctas_per_sk_wave = ctas_per_wave_in_full_clusters; - sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); - - if (decomposition_mode == DecompositionMode::DataParallel || - (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || - sk_units == 0) { - // Short circuit to basic data-parallel decomposition - return DecompositionMode::DataParallel; - } - else { - bool do_separate_reduction = should_perform_separate_reduction( - epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave); - - uint64_t sk_cluster_tiles = sk_tiles / cluster_size; - - groups = calculate_groups(underlying_params, reduction_mode, problem_blocks_m, problem_blocks_n, cluster_shape, - cluster_size, sk_tiles, sk_cluster_tiles, sk_units, k_tiles_per_output_tile, do_separate_reduction); - - auto sk_units_per_group = sk_units / groups; - - // sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as: - // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) - // Both total_tiles and sm_count are multiples of cluster size due to padding added - // prior to kernel launch. - uint64_t sk_cluster_tiles_per_group = sk_cluster_tiles / groups; - uint64_t sk_tiles_per_group = sk_cluster_tiles_per_group * cluster_size; - - // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which - // are stream-K units within a group that process an extra K chunk. - sk_big_groups = sk_cluster_tiles % groups; - - k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; - - // Number of k tiles computed per stream-K unit - k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; - - DecompositionMode heuristic_mode; - if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { - // If the number of stream-K units is a multiple of the number of stream-K tiles, then - // the problem can leverage a basic split-K decomposition for the stream-K tiles. - // This case happens when separate reduction is disable. - sk_splits = static_cast(sk_units / sk_tiles); - heuristic_mode = DecompositionMode::SplitK; - } - else { - // Rest scenario is streamk - heuristic_mode = DecompositionMode::StreamK; - } - // Refresh heuristic_mode using analytical model before choosing streamk/separate_reduction decomposition, - // ideally it's to get the final decomposition more accuracy. Comment it as it is place holder at this moment. - #if 0 - uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); - analytical_model(heuristic_mode, k_tiles_per_output_tile, k_tiles_per_sk_unit, - sk_splits, epilogue_subtile, total_waves); - #endif - return heuristic_mode; - } - } - } - - // Given decomposition mode output from heuristic, set all fields of params. - void set_params( - DecompositionMode heuristic_mode, - uint32_t groups, - uint32_t sk_tiles, - uint64_t sk_units, - uint64_t cluster_size, - uint64_t dp_units, - uint64_t k_tiles_per_group, - uint64_t k_tiles_per_sk_unit, - uint64_t sk_big_groups, - uint32_t sk_splits, - UnderlyingParams underlying_params, - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord cluster_shape, - uint32_t splits, - uint32_t epilogue_subtile, - ReductionMode reduction_mode - , uint32_t ktile_start_alignment_count - ) { - // The highest priority when customers set as splitk mode, may set - // with a adapted splits value rather than the original splits - // even it does not make sense - if (splits > 1 && heuristic_mode == DecompositionMode::SplitK) { - set_params_basic( - underlying_params, - problem_blocks, - cluster_shape, - sk_splits, // split-k set by customers - k_tiles_per_output_tile, - reduction_mode - ); - } - else if (heuristic_mode == DecompositionMode::DataParallel) { - set_params_basic( - underlying_params, - problem_blocks, - cluster_shape, - 1, // fast path to fall back to the mode without any split scheme - k_tiles_per_output_tile, - reduction_mode - ); - } - else if (heuristic_mode == DecompositionMode::SplitK) { - set_params_basic( - underlying_params, - problem_blocks, - cluster_shape, - sk_splits, // splits calculated by heuristic - k_tiles_per_output_tile, - reduction_mode - ); - } - else { - // streamk - set_params_stream_k( - underlying_params, - k_tiles_per_output_tile, - groups, - sk_tiles, - sk_units, - cluster_size, - dp_units, - k_tiles_per_group, - k_tiles_per_sk_unit, - sk_big_groups, - reduction_mode, - 1, /*epilogue_subtile*/ - 0 /*reduction_units*/ - ); - } - } - - // Given the inputs, computes the physical grid we should launch. - // This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - CUTLASS_HOST_DEVICE - static dim3 - get_grid_shape( - BatchedGemmCoord problem_shape, - GemmCoord cta_shape, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - bool bypass_sm90_occupancy_calculation=false - ) { - - dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); - - return get_grid_shape( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle_size, - raster_order_option, - bypass_sm90_occupancy_calculation - ); - } - - // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE - static dim3 - get_grid_shape( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - bool bypass_sm90_occupancy_calculation=false - ) { - - // Call into the underlying get_grid_shape method, but do not allow the grid shape returned - // to be truncated based on the number of output tiles in the problem. - return UnderlyingParams::get_grid_shape( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle_size, - raster_order_option, - /* truncate_by_problem_size = */false, - bypass_sm90_occupancy_calculation - ); - } - - // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total - // output tiles on a device with `ctas_per_wave` CTAs in each wave. - static uint32_t - get_num_sk_tiles( - uint64_t output_tiles, - uint64_t ctas_per_wave, - uint64_t cluster_size, - uint32_t k_tiles_per_output_tile, - DecompositionMode decomposition_mode - , uint64_t ctas_per_wave_in_full_clusters - ) { - uint32_t full_waves = static_cast(output_tiles / ctas_per_wave); - uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); - - if (decomposition_mode == DecompositionMode::DataParallel || - decomposition_mode == DecompositionMode::SplitK) { - return 0; - } - - // If there is wave quantization, assign the first two waves worth of tiles to be - // covered by stream-K work and the remainder to be data-parallel. Since we know - // that full_waves == total_waves - 1 in this case, the number of data-parallel - // waves is simply full_waves-1 (unless full_waves == 0). - uint32_t dp_waves = full_waves > 1 ? full_waves - 1 : 0; - uint64_t dp_tiles = dp_waves * ctas_per_wave; - uint64_t sk_tiles = output_tiles - dp_tiles; - - if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) { - // All tiles will be data-parallel tiles if there is either no quantization - // or if there is no work to be split. - return 0; - } - - // - // The final wave is not full. Perform some stream-K work. - // - if (decomposition_mode == DecompositionMode::Heuristic) { - // Rudimentary heuristic: prefer data-parallel decomposition if we have more than - // one wave and the tail wave is more than half full. This is subject to change. - uint64_t tail_tiles = output_tiles - (full_waves * ctas_per_wave); - if (2 * tail_tiles >= ctas_per_wave) { - return 0; - } - } - // Ensure that the number of SK tiles is divisible by cluster size so that it can be evenly - // divided among SK clusters. - sk_tiles = (sk_tiles / cluster_size) * cluster_size; - - return static_cast(sk_tiles); - } - - CUTLASS_HOST_DEVICE - static uint64_t - get_num_sk_units(GemmCoord cluster_shape, uint64_t ctas_per_sk_wave, uint32_t sk_tiles, uint32_t k_tiles_per_output_tile) { - // If there are stream-K tiles to compute and a sufficiently large number of k iterations - // across them, they will be covered by a single wave of persistent threadblocks. Thus, there - // will be as many work units as there are threadblocks in a single wave. - // - // When the total k iterations across stream-K tiles is too small to justify distributing - // across an entire wave of blocks, we instead distribute the iterations over a smaller - // set of blocks. - - // Calculate the number of stream-K units that would be needed if each stream-K unit - // computed the minimum allowable k iterations. Truncate this to be in units of clusters. - - // Number of k iterations computed by the stream-K units as a whole - uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles; - - // Calculate the number of stream-K units that would be needed if each stream-K unit - // computed the minimum allowable k iterations. Truncate this to be in units of clusters. - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_); - min_sized_sk_units = (min_sized_sk_units / cluster_size) * cluster_size; - - uint64_t sk_units = platform::min(ctas_per_sk_wave, min_sized_sk_units); - return sk_units; - } - - // Calculates the size of the workspace needed for holding reduction barriers - CUTLASS_HOST_DEVICE - static size_t - get_barrier_workspace_size(uint64_t num_tiles, uint32_t mma_warp_groups, uint32_t barrier_bits) { - size_t workspace_bits = num_tiles * static_cast(mma_warp_groups) * static_cast(barrier_bits); - return round_up_to_l2_alignment(bits_to_bytes(workspace_bits)); - } - - // Calculates the size of the workspace needed for holding partial outputs from splits - CUTLASS_HOST_DEVICE - static size_t - get_reduction_workspace_size(uint64_t num_tiles, GemmCoord tile_shape, uint32_t accumulator_bits, uint32_t num_accumulator_mtxs = 1) { - size_t output_tile_size = tile_shape.m() * tile_shape.n(); - size_t workspace_bits = accumulator_bits * output_tile_size * num_tiles * num_accumulator_mtxs; - return round_up_to_l2_alignment(bits_to_bytes(workspace_bits)); - } - - #if !defined(__CUDACC_RTC__) - static void - get_workspace_component_sizes( - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord tile_shape, - GemmCoord cluster_shape, - size_t& barrier_workspace_size, - size_t& reduction_workspace_size, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t mma_warp_groups, - uint32_t barrier_bits, - uint32_t accumulator_bits, - uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1, - uint32_t ktile_start_alignment_count = 1, - bool bypass_sm90_occupancy_calculation=false) { - - auto log_swizzle_size = UnderlyingParams::get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle); - problem_blocks.x = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - problem_blocks.y = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); - - // Workspace is needed only for output tiles that will be split. Thus, we first determine the number - // of output tiles that will be split, and then calculate the workspace needed to cover these. - uint64_t output_tiles = problem_blocks.x * problem_blocks.y * problem_blocks.z; - - if (decomposition_mode == DecompositionMode::DataParallel) { - barrier_workspace_size = 0; - reduction_workspace_size = 0; - } - else { - KernelHardwareInfo new_hw_info; - new_hw_info.device_id = hw_info.device_id; - new_hw_info.sm_count = hw_info.sm_count; - new_hw_info.max_active_clusters = hw_info.max_active_clusters; - if (new_hw_info.sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - new_hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(new_hw_info.device_id); - } - - dim3 grid = get_grid_shape( - problem_blocks, - cluster_shape, - new_hw_info, - max_swizzle, - raster_order_option, - bypass_sm90_occupancy_calculation - ); - uint64_t ctas_per_wave = grid.x * grid.y; - uint64_t cluster_size = cluster_shape.m() * cluster_shape.n(); - uint64_t ctas_per_wave_in_full_clusters = (ctas_per_wave / cluster_size) * cluster_size; - uint32_t sk_tiles = get_num_sk_tiles( - output_tiles, - ctas_per_wave, - cluster_size, - static_cast(k_tiles_per_output_tile), - decomposition_mode - , ctas_per_wave_in_full_clusters - ); - uint64_t ctas_per_sk_wave = ctas_per_wave; - ctas_per_sk_wave = ctas_per_wave_in_full_clusters; - uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); - uint64_t dp_tiles = output_tiles - sk_tiles; - - if (decomposition_mode == DecompositionMode::SplitK || - (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { - splits = adjust_split_count( - splits, new_hw_info.sm_count, k_tiles_per_output_tile - , ktile_start_alignment_count - ); - } - - bool split_k_required = splits > 1 && (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic); - bool split_k_selected = !split_k_required && - decomposition_mode == DecompositionMode::Heuristic && - sk_units > sk_tiles && - sk_tiles != 0 && - sk_units % sk_tiles == 0; - - if (split_k_required || split_k_selected) { - // Basic split-K variant requires workspace for all output tiles - barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); - } - else { - uint64_t reduction_tiles = sk_tiles; - if ( - should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave) - ) { - // In separate reduction, each peer writes to its own location in scratch space. - // Thus, for separate reduction, we need as many reduction tiles per output tile - // as there are the maximum number of peers that can collaborate on an output tile. - reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles); - } - - // Though separate reduction requires a larger reduction workspace, only one barrier - // is needed per output tile. Each peer will increment the barrier by one once the peer has - // written its accumulator to scratch space. The separate reduction unit will only begin - // performing the reduction when the barrier has reached the number of peers for the output tile. - barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); - } - } - } - #endif // !defined(__CUDACC_RTC__) - - // Returns whether the kernel is configured in a manner for which separate reduction should be used - CUTLASS_HOST_DEVICE - static bool - should_perform_separate_reduction(uint32_t, uint64_t, uint64_t, uint64_t, uint64_t) { - // Separate reduction is temporarily disabled, pending fixes - return false; - } - - // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - static size_t - get_workspace_size( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t mma_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile, - uint32_t num_accumulator_mtxs, - uint32_t ktile_start_alignment_count = 1) { - - dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); - - return get_workspace_size( - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - mma_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - num_accumulator_mtxs, - ktile_start_alignment_count - ); - } - - // Version of get_workspace_size that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - static size_t - get_workspace_size( - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t mma_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1, - uint32_t ktile_start_alignment_count = 1, - bool bypass_sm90_occupancy_calculation=false) { - - size_t barrier_workspace_size = 0; - size_t reduction_workspace_size = 0; - - #if !defined(__CUDACC_RTC__) - get_workspace_component_sizes( - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - barrier_workspace_size, - reduction_workspace_size, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - mma_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - num_accumulator_mtxs, - ktile_start_alignment_count, - bypass_sm90_occupancy_calculation - ); - #endif - - return barrier_workspace_size + reduction_workspace_size; - } - - // Initialize the workspace to be used for the kernel. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - static cutlass::Status - initialize_workspace( - void* workspace, - cudaStream_t stream, - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t mma_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile, - CudaHostAdapter* cuda_adapter = nullptr, - uint32_t ktile_start_alignment_count = 1) { - - dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); - - return initialize_workspace( - workspace, - stream, - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - mma_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - 1, - cuda_adapter, - ktile_start_alignment_count - ); - } - - // Version of initialize_workspace that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - static cutlass::Status - initialize_workspace( - void* workspace, - cudaStream_t stream, - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t mma_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1, - CudaHostAdapter* cuda_adapter = nullptr, - uint32_t ktile_start_alignment_count = 1, - bool bypass_sm90_occupancy_calculation=false) { - - #if !defined(__CUDACC_RTC__) - uint64_t barrier_workspace_size = 0; - uint64_t reduction_workspace_size = 0; - - get_workspace_component_sizes( - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - barrier_workspace_size, - reduction_workspace_size, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - mma_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - num_accumulator_mtxs, - ktile_start_alignment_count, - bypass_sm90_occupancy_calculation - ); - - if (barrier_workspace_size > 0) { - if (workspace == nullptr) { - return Status::kErrorWorkspaceNull; - } - - // Only the barrier workspace needs to be cleared for stream-K. - // Barrier workspace follows reduction workspace. - uint8_t* barrier_workspace = reinterpret_cast(workspace) + reduction_workspace_size; - return zero_workspace(static_cast(barrier_workspace), barrier_workspace_size, stream, cuda_adapter); - } - #endif // !defined(__CUDACC_RTC__) - - return Status::kSuccess; - } - - // Set params for basic parameters, which will not affected by different decompositions. - void - set_params_base(UnderlyingParams const& underlying_params, void* reduction_workspace) { - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; - divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; - divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; - log_swizzle_size_ = underlying_params.log_swizzle_size_; - raster_order_ = underlying_params.raster_order_; - reduction_workspace_ = reduction_workspace; - } - - void - set_params_basic( - UnderlyingParams const& underlying_params, - dim3 problem_blocks, - GemmCoord cluster_shape, - uint32_t splits, - uint32_t k_tiles_per_output_tile, - ReductionMode reduction_mode) { - - auto blocks_l = problem_blocks.z; - auto blocks_m = round_up(problem_blocks.x, - (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); - auto blocks_n = round_up(problem_blocks.y, - (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); - - divmod_batch_ = FastDivmodU64(blocks_m * blocks_n); - divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); - divmod_sk_groups_ = FastDivmodU64(1u); - auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * - underlying_params.divmod_cluster_shape_minor_.divisor; - divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size); - divmod_splits_ = FastDivmod(splits); - units_per_problem_ = blocks_m * blocks_n * blocks_l; - big_units_ = k_tiles_per_output_tile % splits; - reduction_mode_ = reduction_mode; - divmod_k_tiles_per_sk_unit_ = FastDivmod(k_tiles_per_output_tile / splits); - divmod_k_tiles_per_sk_big_unit_ = FastDivmod(k_tiles_per_output_tile / splits + 1); - - // No stream-K work is performed for "basic" data-parallel and split-K decompositions - sk_tiles_ = 0; - sk_units_ = 0; - divmod_sk_units_per_group_ = FastDivmodU64(1u); - separate_reduction_units_ = 0; - } - - // Set params for streamk(streamk, separate-reduction included) decomposition. - void - set_params_stream_k( - UnderlyingParams const& underlying_params, - uint32_t k_tiles_per_output_tile, - uint32_t groups, - uint32_t sk_tiles, - uint64_t sk_units, - uint64_t cluster_size, - uint64_t dp_units, - uint64_t k_tiles_per_group, - uint64_t k_tiles_per_sk_unit, - uint64_t sk_big_groups, - ReductionMode reduction_mode, - uint32_t epilogue_subtile, - uint32_t reduction_units) { - // stream-k and separate-reduction decompostions - divmod_batch_ = underlying_params.divmod_batch_; - divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); - divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); - divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); - - // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. - // This setting ensures that the use of this divmod for stream-K decompositions - // is essentially a no-op. - divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); - divmod_splits_ = FastDivmod(1); - units_per_problem_ = static_cast(dp_units + sk_units); - - // Assign big_units_ assuming that group count == 1. This is unused by stream-K - // when group count > 1. - auto big_units_in_ctas = k_tiles_per_group % sk_units; - - // Store big_units in terms of clusters. big_units_in_ctas is guaranteed to be divisible - // by cluster_size because both k_tiles_per_group and k_tiles_per_sk_unit must be a multiple - // of cluster_size. - auto big_units_in_clusters = big_units_in_ctas / cluster_size; - big_units_ = static_cast(big_units_in_clusters); - - big_groups_ = static_cast(sk_big_groups); - sk_tiles_ = sk_tiles; - sk_units_ = static_cast(sk_units); - divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); - divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); - reduction_mode_ = reduction_mode; - divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); - separate_reduction_units_ = reduction_units; - } - - private: - // Round up number of bytes to the nearest multiple of L2 cache line alignment - CUTLASS_HOST_DEVICE - static size_t - round_up_to_l2_alignment(size_t bytes) { - constexpr size_t L2CacheLineSizeBytes = 128u; - return (bytes + L2CacheLineSizeBytes - 1) / L2CacheLineSizeBytes * L2CacheLineSizeBytes; - } - - CUTLASS_HOST_DEVICE - static int adjust_split_count( - int splits, - int sm_count, - uint32_t k_tiles_per_output_tile - , uint32_t ktile_start_alignment_count - ) { - // Don't split by more than the available number of SMs - if (splits > sm_count) { - splits = sm_count; - } - - // Don't split by more than the K tile iterations - if (static_cast(splits) > k_tiles_per_output_tile) { - splits = k_tiles_per_output_tile; - } - - // If k_tiles_per_output_tiles / splits == 1, there will be one k_tile per cta - // and this violate k_tile start from even requirements. Thus we need to - // reduce the number of splits. - if (ktile_start_alignment_count > 1u && - splits > 1 && - k_tiles_per_output_tile / static_cast(splits) == 1) { - splits = k_tiles_per_output_tile / ktile_start_alignment_count; - } - return splits; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -// Parameters for SM90 persistent group scheduler (only used for Grouped Gemms) -template -struct PersistentTileSchedulerSm90GroupParams { - using RasterOrder = cutlass::gemm::kernel::detail::RasterOrder; - using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; - - FastDivmodU64Pow2 divmod_cluster_shape_major_{}; - FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; - FastDivmodU64 divmod_cta_shape_m_{}; - FastDivmodU64 divmod_cta_shape_n_{}; - - uint64_t blocks_across_problem_ = 0; - bool pre_processed_problem_shapes = true; - int32_t log_swizzle_size_ = 0; - RasterOrder raster_order_ = RasterOrder::AlongN; - - GroupProblemShape problem_shapes_; - GemmCoord cta_shape_; - GemmCoord cluster_shape_; - - // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - void - initialize( - dim3 problem_blocks, - GroupProblemShape problem_shapes, - GemmCoord cta_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - - CUTLASS_UNUSED(hw_info); - - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); - - RasterOrder raster_order = get_rasterization_order( - problem_blocks_m, - problem_blocks_n, - raster_order_option - ); - - // - // Set members - // - problem_shapes_ = problem_shapes; - cta_shape_ = cta_shape; - cluster_shape_ = cluster_shape; - - blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z; - pre_processed_problem_shapes = problem_shapes.is_host_problem_shape_available(); - log_swizzle_size_ = log_swizzle_size; - raster_order_ = raster_order; - - if (raster_order == RasterOrder::AlongN) { - divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n()); - divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.m()); - } - else { - divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m()); - divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n()); - } - - divmod_cta_shape_m_ = FastDivmodU64(cta_shape_.m()); - divmod_cta_shape_n_ = FastDivmodU64(cta_shape_.n()); - } - - // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE - static dim3 - get_tiled_cta_shape_mnl(GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) { - // Round up to nearest multiple of cluster dim along each mode - auto problem_blocks_m = ((cta_m + cluster_shape.m() - 1) / cluster_shape.m()) * cluster_shape.m(); - auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n(); - - return { - static_cast(cta_m), - static_cast(cta_n), - static_cast(1) // Only a single batch per group is currently supported - }; - } - - // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - bool truncate_by_problem_size=true) { - - int const sm_count = hw_info.sm_count; - int const max_active_clusters = hw_info.max_active_clusters; - - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); - - int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; - - RasterOrder raster_order = get_rasterization_order( - problem_blocks_m, - problem_blocks_n, - raster_order_option - ); - - dim3 launch_grid; - - if (raster_order == RasterOrder::AlongN) { - launch_grid = dim3(cluster_shape.m(), 1, 1); - } - else { - launch_grid = dim3(1, cluster_shape.n(), 1); - } - - auto possibly_truncate = [&](int x, int y) { - if (truncate_by_problem_size) { - return platform::min(x, y); - } - else { - return x; - } - }; - - // The else path is generic, however, we can avoid some divs if we know cluster size is 1 - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - if (cluster_size == 1) { - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate(sm_count, problem_blocks_total); - } - else { - launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); - } - } - // In case the maximum number of clusters that could co-exist on the target device is - // already calculated using cudaOccupancyMaxActiveClusters - else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) { - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = max_active_clusters * cluster_shape.n(); - } - else { - launch_grid.x = max_active_clusters * cluster_shape.m(); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using cudaOccupancyMaxActiveClusters = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - else { - // Optimal grid size calculation is based on - // GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU - // Hence, maximum SMs per GPC = 18 - constexpr int max_sm_per_gpc = 18; - int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); - - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate( - cta_per_device / cluster_shape.m(), - problem_blocks_total / cluster_shape.m()); - } - else { - launch_grid.x = possibly_truncate( - cta_per_device / cluster_shape.n(), - problem_blocks_total / cluster_shape.n()); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using heuristics = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - return launch_grid; - } - - CUTLASS_HOST_DEVICE - static int32_t - get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { - int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n); - if (max_swizzle_size >= 8 && min_cta_dim >= 6) { - return 3; - } - else if (max_swizzle_size >= 4 && min_cta_dim >= 3) { - return 2; - } - else if (max_swizzle_size >= 2 && min_cta_dim >= 2) { - return 1; - } - else { - return 0; - } - } - - CUTLASS_HOST_DEVICE - static RasterOrder - get_rasterization_order( - uint32_t tiles_m, - uint32_t tiles_n, - RasterOrderOptions raster_order_option - ) { - - if (raster_order_option == RasterOrderOptions::Heuristic) { - if (tiles_n > tiles_m) { - return RasterOrder::AlongM; - } - else { - return RasterOrder::AlongN; - } - } - else { - switch (raster_order_option) { - case RasterOrderOptions::AlongN: - return RasterOrder::AlongN; - break; - default: - return RasterOrder::AlongM; - } - } - } -}; - -//////////////////////////////////////////////////////////////////////////////// - - -// -// Parameters for SM100 tile schedulers -// - -// Parameters for SM100 persistent tile scheduler -struct PersistentTileSchedulerSm100Params { - - using UnderlyingParams = PersistentTileSchedulerSm90Params; - - using RasterOrder = UnderlyingParams::RasterOrder; - using RasterOrderOptions = UnderlyingParams::RasterOrderOptions; - - uint32_t problem_tiles_m_ = 0; - uint32_t problem_tiles_n_ = 0; - uint32_t problem_tiles_l_ = 0; - FastDivmod divmod_cluster_shape_m_{}; - FastDivmod divmod_cluster_shape_n_{}; - FastDivmod divmod_swizzle_size_{}; - RasterOrder raster_order_ = RasterOrder::AlongM; - int32_t log_swizzle_size_ = 0; - // Initializes members. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - void - initialize( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - initialize( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle_size, - raster_order_option - ); - } - - void initialize_swizzle( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option) { - - raster_order_ = UnderlyingParams::get_rasterization_order(problem_tiles_m_, problem_tiles_n_, raster_order_option); - if (raster_order_option == RasterOrderOptions::Heuristic && raster_order_ == RasterOrder::AlongN) { - // The current implementation of AlongN rasterization for B100 requires swapping the number of clusters along the - // X and Y dimensions of the grid. However, since the grid Y dimension has a smaller range of allowed values - // than the grid X dimension, we must check whether the swapped grid would exceed the grid Y limit. If the - // swapped grid would exceed this limit, simply rever to AlongM mode. - // - // Overflow in the swapped X dimension is not possible. At worst, there will be ((1 << 16) - 1) clusters - // along the original Y dimension of the grid. Even if the cluster M mode is 16, the new grid X value - // will be at most ((1 << 16) - 1) * 16, which is less than the grid X limit of ((1 << 31) - 1). - uint32_t new_grid_y = problem_tiles_m_ * static_cast(cluster_shape.n()); - - if (new_grid_y > (1 << 16) - 1) { - raster_order_ = RasterOrder::AlongM; - } - } - - if (max_swizzle_size <= 1) { - // Set divisors directly to be zero to mark as unused - divmod_swizzle_size_.divisor = 0; - } - else { - divmod_swizzle_size_ = FastDivmod(max_swizzle_size); - } - } - - // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - void - initialize( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - - // Cluster counters in m, n and l dimensions of the problem tiles - problem_tiles_m_ = problem_blocks.x / cluster_shape.m(); - problem_tiles_n_ = problem_blocks.y / cluster_shape.n(); - problem_tiles_l_ = problem_blocks.z; - divmod_cluster_shape_m_ = FastDivmod(cluster_shape.m()); - divmod_cluster_shape_n_ = FastDivmod(cluster_shape.n()); - - initialize_swizzle(problem_blocks, cluster_shape, hw_info, max_swizzle_size, raster_order_option); - } - - // Given the inputs, computes the physical grid we should launch. - // This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - BatchedGemmCoord problem_shape, - GemmCoord cta_shape, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - - CUTLASS_UNUSED(cluster_shape); - CUTLASS_UNUSED(hw_info); - CUTLASS_UNUSED(max_swizzle_size); - CUTLASS_UNUSED(raster_order_option); - - return get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); - } - - // Get the number of CTA tiles in this problem. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - CUTLASS_HOST_DEVICE - static dim3 - get_tiled_cta_shape_mnl( - BatchedGemmCoord problem_shape, - GemmCoord cta_shape, - GemmCoord cluster_shape) { - - return UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); - } - - // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - static size_t - get_workspace_size( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle, - RasterOrderOptions raster_order_option - ) { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - return get_workspace_size( - problem_blocks, - cluster_shape, - hw_info, - max_swizzle, - raster_order_option - ); - } - - // Version of get_workspace_size that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - static size_t - get_workspace_size( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle, - RasterOrderOptions raster_order_option - ) { - - CUTLASS_UNUSED(problem_blocks); - CUTLASS_UNUSED(cluster_shape); - CUTLASS_UNUSED(hw_info); - CUTLASS_UNUSED(max_swizzle); - CUTLASS_UNUSED(raster_order_option); - - return 0; - } - - // Initialize the workspace to be used for the kernel. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - static cutlass::Status - initialize_workspace( - void* workspace, - cudaStream_t stream, - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle, - RasterOrderOptions raster_order_option, - CudaHostAdapter *cuda_adapter = nullptr - ) { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - return initialize_workspace( - workspace, - stream, - problem_blocks, - cluster_shape, - hw_info, - max_swizzle, - raster_order_option, - cuda_adapter - ); - } - - // Version of initialize_workspace that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - static cutlass::Status - initialize_workspace( - void* workspace, - cudaStream_t stream, - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle, - RasterOrderOptions raster_order_option, - CudaHostAdapter *cuda_adapter = nullptr - ) { - - CUTLASS_UNUSED(workspace); - CUTLASS_UNUSED(stream); - CUTLASS_UNUSED(problem_blocks); - CUTLASS_UNUSED(cluster_shape); - CUTLASS_UNUSED(hw_info); - CUTLASS_UNUSED(max_swizzle); - CUTLASS_UNUSED(raster_order_option); - - return cutlass::Status::kSuccess; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Parameters for SM100 persistent stream-K tile scheduler -struct PersistentTileSchedulerSm100StreamKParams { - using UnderlyingParams = PersistentTileSchedulerSm100Params; - using UnderlyingStreamKParams = PersistentTileSchedulerSm90StreamKParams; - using RasterOrderOptions = UnderlyingParams::RasterOrderOptions; - using ReductionMode = UnderlyingStreamKParams::ReductionMode; - using DecompositionMode = UnderlyingStreamKParams::DecompositionMode; - - using RasterOrder = UnderlyingParams::RasterOrder; - RasterOrder raster_order_ = RasterOrder::AlongM; - int32_t log_swizzle_size_ = 0; - - UnderlyingStreamKParams sk_params_{}; - UnderlyingParams sm100_params_{}; - - // Initializes members. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - void - initialize( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - ReductionMode reduction_mode, - DecompositionMode decomposition_mode, - void* workspace, - uint32_t ktile_start_alignment_count = 1u - ) { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - - // Number of k tiles in each output tile - uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); - - initialize( - problem_blocks, - k_tiles_per_output_tile, - cluster_shape, - hw_info, - splits, - max_swizzle_size, - raster_order_option, - reduction_mode, - decomposition_mode, - workspace, - ktile_start_alignment_count - ); - } - - // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - void - initialize( - dim3 problem_blocks, - uint32_t k_tile_per_output_tile, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - ReductionMode reduction_mode, - DecompositionMode decomposition_mode, - void* workspace, - uint32_t ktile_start_alignment_count = 1u - ) { - sk_params_.initialize( - problem_blocks, - k_tile_per_output_tile, - cluster_shape, - hw_info, - splits, - max_swizzle_size, - raster_order_option, - reduction_mode, - decomposition_mode, - workspace, - /*epilogue_subtile=*/1, - ktile_start_alignment_count, - /*bypass_sm90_occupancy_calculation=*/true - ); - - log_swizzle_size_ = sk_params_.log_swizzle_size_; - raster_order_ = sk_params_.raster_order_; - - sm100_params_.initialize( - problem_blocks, - cluster_shape, - hw_info, - 0, // Override max_swizzle_size to be 0, since the SM100 stream-K scheduler handles swizzling on its own - RasterOrderOptions::AlongM // Override raster_order to be AlongM, since the SM100 stream-K scheduler does not require grid swapping for raster order selection - ); - } - - // Get the number of CTA tiles in this problem. - CUTLASS_HOST_DEVICE - static dim3 - get_tiled_cta_shape_mnl( - BatchedGemmCoord problem_shape, - GemmCoord cta_shape, - GemmCoord cluster_shape) { - - return UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); - } - - // Given the inputs, computes the physical grid we should launch. - // This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - CUTLASS_HOST_DEVICE - dim3 - get_grid_shape(BatchedGemmCoord problem_shape, GemmCoord cta_shape, GemmCoord cluster_shape) const { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); - - return get_grid_shape(problem_blocks, cluster_shape); - } - - // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE - dim3 - get_grid_shape(dim3 problem_blocks, GemmCoord cluster_shape) const { - if (sk_params_.sk_units_ > 0) { - // For stream-K cases, we would, ideally, launch a linear grid of size `sk_params_.units_per_problem_`. - // However doing so raises two potential issues: - // (a) the total number of tiles in the kernel may exceed the amount that can fit in a single - // returned value of a CLC query - // (b) the launched grid would not respect cluster-size divisibility requirements - // - // To circumvent these issues, we must distribute the `sk_params_.units_per_problem_` units of work - // across the X, Y, and Z dimensions of the grid, while ensuring that the X and Y dimensions are - // divisible by cluster size (we ignore Z, as all CUTLASS kernels currently use a cluster shape - // of 1 in the Z dimension). - // - // For convenience, we launch this as "waves" of `sk_params_.sk_units_` CTAs, with the wave count being - // the Z dimension of the grid, and the `sk_params_.sk_units_` CTAs per wave being distributed across - // the X and Y dimensions of the grid in a way that alingns with cluster divisibility requirements. - // - // Thus, the grid that is launched looks like: - // grid = dim3(sk_units_ / cluster.y, cluster.y, waves) - // - // We place sk_units_ / cluster.y in the X dimension of the grid because the CLC query feature - // allocates more bits for the X index values returned in the query. - // - - // For most cases, `sk_params_.sk_units_` will equal the number of available SMs, so this grid will - // naturally represent waves in the true hardware sense. - // - // However, there are some corner cases in which fewer stream-K units are used than the full SM count - // (e.g., if using the full SM count would result in stream-K units that are assigned fewer than the - // minimum number of K tile iterations). In these cases, `sk_params_.units_per_problem_` may not be - // divisible by `sk_params_.sk_units_`, since any data-parallel work performed alongside stream-K - // work is always done in terms of waves of CTAs of number equal to the number of available SMs. - // Therefore, we take the ceiling of the division when determining wave count, and allow the underlying - // stream-K scheduler to determine which indices are in bounds. - uint32_t waves = static_cast( - (sk_params_.units_per_problem_ + sk_params_.sk_units_ - 1) / sk_params_.sk_units_); - - return dim3( - sk_params_.sk_units_ / cluster_shape.n(), - cluster_shape.n(), - waves - ); - } - else { - // Grid launch for data-parallel and basic split-K decomposition. When data-parallel - // mode is used, params.sk_params_.splits = 1. - return dim3(problem_blocks.x, problem_blocks.y, problem_blocks.z * sk_params_.divmod_splits_.divisor); - } - } - - // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - static size_t - get_workspace_size( - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t reduction_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t ktile_start_alignment_count = 1 - ) { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); - - return get_workspace_size( - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - reduction_warp_groups, - barrier_bits, - element_accumulator_bits, - ktile_start_alignment_count - ); - } - - // Version of get_workspace_size that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - static size_t - get_workspace_size( - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t reduction_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1, - uint32_t ktile_start_alignment_count = 1 - ) { - return UnderlyingStreamKParams::get_workspace_size( - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - reduction_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - num_accumulator_mtxs, - ktile_start_alignment_count, - /*bypass_sm90_occupancy_calculation=*/true - ); - } - - // Initialize the workspace to be used for the kernel. This variant of the method should only be used when - // problem_shape and tile_shape contain modes of only rank 1. - static cutlass::Status - initialize_workspace( - void* workspace, - cudaStream_t stream, - BatchedGemmCoord problem_shape, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t reduction_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1, - CudaHostAdapter *cuda_adapter = nullptr, - uint32_t ktile_start_alignment_count = 1 - ) { - dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); - uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); - - return initialize_workspace( - workspace, - stream, - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - reduction_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - num_accumulator_mtxs, - cuda_adapter, - ktile_start_alignment_count - ); - } - - // Version of initialize_workspace that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - static cutlass::Status - initialize_workspace( - void* workspace, - cudaStream_t stream, - dim3 problem_blocks, - uint32_t k_tiles_per_output_tile, - GemmCoord tile_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int splits, - int max_swizzle, - RasterOrderOptions raster_order_option, - DecompositionMode decomposition_mode, - ReductionMode reduction_mode, - uint32_t reduction_warp_groups, - uint32_t barrier_bits, - uint32_t element_accumulator_bits, - uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1, - CudaHostAdapter *cuda_adapter = nullptr, - uint32_t ktile_start_alignment_count = 1 - ) { - return UnderlyingStreamKParams::initialize_workspace( - workspace, - stream, - problem_blocks, - k_tiles_per_output_tile, - tile_shape, - cluster_shape, - hw_info, - splits, - max_swizzle, - raster_order_option, - decomposition_mode, - reduction_mode, - reduction_warp_groups, - barrier_bits, - element_accumulator_bits, - epilogue_subtile, - num_accumulator_mtxs, - cuda_adapter, - ktile_start_alignment_count, - /*bypass_sm90_occupancy_calculation=*/true - ); - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// Parameters for SM100 persistent group scheduler (only used for Grouped Gemms) -template -struct PersistentTileSchedulerSm100GroupParams { - - using UnderlyingSm90Params = PersistentTileSchedulerSm90GroupParams; - using RasterOrder = cutlass::gemm::kernel::detail::RasterOrder; - using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; - - UnderlyingSm90Params params_sm90_{}; - - // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - void - initialize( - dim3 problem_blocks, - GroupProblemShape problem_shapes, - GemmCoord cta_shape, - GemmCoord cluster_shape, - KernelHardwareInfo const& hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option - ) { - - params_sm90_.initialize( - problem_blocks, - problem_shapes, - cta_shape, - cluster_shape, - hw_info, - max_swizzle_size, - raster_order_option - ); - } - - // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE - static dim3 - get_tiled_cta_shape_mnl(GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) { - return UnderlyingSm90Params::get_tiled_cta_shape_mnl(cluster_shape, cta_m, cta_n); - } - - // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. - // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, - // for which using CuTe algebra for calculating tile shapes is easiest. - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - dim3 problem_blocks, - GemmCoord cluster_shape, - KernelHardwareInfo hw_info, - int max_swizzle_size, - RasterOrderOptions raster_order_option, - bool truncate_by_problem_size = true, - bool is_static_cluster_shape = false) { - - int const sm_count = hw_info.sm_count; - int const max_active_clusters = hw_info.max_active_clusters; - - // Round up to nearest multiple of swizzle_size along each mode - auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); - - int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; - - RasterOrder raster_order = get_rasterization_order( - problem_blocks_m, - problem_blocks_n, - raster_order_option - ); - - dim3 launch_grid; - - if (raster_order == RasterOrder::AlongN) { - launch_grid = dim3(cluster_shape.m(), 1, 1); - } - else { - launch_grid = dim3(1, cluster_shape.n(), 1); - } - - auto possibly_truncate = [&](int x, int y) { - if (truncate_by_problem_size) { - return platform::min(x, y); - } - else { - return x; - } - }; - - if (is_static_cluster_shape) { - // The else path is generic, however, we can avoid some divs if we know cluster size is 1 - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - if (cluster_size == 1) { - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate(sm_count, problem_blocks_total); - } - else { - launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); - } - } - // In case the maximum number of clusters that could co-exist on the target device is - // already calculated using cudaOccupancyMaxActiveClusters - else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) { - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = max_active_clusters * cluster_shape.n(); - } - else { - launch_grid.x = max_active_clusters * cluster_shape.m(); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using cudaOccupancyMaxActiveClusters = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - else { - constexpr int max_sm_per_gpc = 20; - int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = possibly_truncate( - cta_per_device / cluster_shape.m(), - problem_blocks_total / cluster_shape.m()); - } - else { - launch_grid.x = possibly_truncate( - cta_per_device / cluster_shape.n(), - problem_blocks_total / cluster_shape.n()); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using heuristics = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - } - else { - // With preferred clusters, we can launch the largest possible persistent grid (rounded up to cluster dims) - if (raster_order == RasterOrder::AlongN) { - launch_grid.y = ((possibly_truncate(sm_count, problem_blocks_total) / cluster_shape.m()) / cluster_shape.n()) * cluster_shape.n(); - } - else { - launch_grid.x = ((possibly_truncate(sm_count, problem_blocks_total) / cluster_shape.n()) / cluster_shape.m()) * cluster_shape.m(); - } - CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using preferred clusters = " - "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); - } - return launch_grid; - } - - CUTLASS_HOST_DEVICE - static int32_t - get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { - return UnderlyingSm90Params::get_log_swizzle_size(problem_ctas_m, problem_ctas_n, max_swizzle_size); - } - - CUTLASS_HOST_DEVICE - static RasterOrder - get_rasterization_order( - uint32_t tiles_m, - uint32_t tiles_n, - RasterOrderOptions raster_order_option - ) { - return UnderlyingSm90Params::get_rasterization_order(tiles_m, tiles_n, raster_order_option); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - - -} // namespace detail -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/trmm_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/trmm_universal.h deleted file mode 100644 index 992aa484ff8e789b037fced736af1baa8b93502c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/trmm_universal.h +++ /dev/null @@ -1,580 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/complex.h" -#include "cutlass/semaphore.h" -#include "cutlass/core_io.h" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - SideMode SideMode_, ///! Side Mode for the kernel (kLeft or kRight) - FillMode FillMode_, ///! Fill Mode for triangular matrix (kLower or kUpper) - DiagType DiagType_ ///! Diag Type for triangular matrix (kNonUnit or kUnit) -> -struct TrmmUniversal { -public: - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static SideMode const kSideMode = SideMode_; - static FillMode const kFillMode = FillMode_; - static DiagType const kDiagType = DiagType_; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); - - // - // Structures - // - - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmUniversalMode mode{GemmUniversalMode::kGemm}; - GemmCoord problem_size{}; - int batch_count{1}; - - typename EpilogueOutputOp::Params epilogue{}; - - void const * ptr_A{nullptr}; - void const * ptr_B{nullptr}; - void * ptr_D{nullptr}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_D{0}; - - typename LayoutA::Stride::Index lda{0}; - typename LayoutB::Stride::Index ldb{0}; - typename LayoutC::Stride::Index ldd{0}; - - // - // Methods - // - - Arguments() = default; - - /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void * ptr_D, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_D, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldd - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_D(batch_stride_D), - lda(lda), ldb(ldb), ldd(ldd) { - } - - /// Returns arguments for the transposed problem sizes - Arguments transposed_problem_size() const { - Arguments args(*this); - - std::swap(args.problem_size.m(), args.problem_size.n()); - - return args; - } - - /// Returns arguments for the transposed matrices - Arguments swapped_matrices() const { - Arguments args(*this); - - std::swap(args.ptr_A, args.ptr_B); - std::swap(args.lda, args.ldb); - std::swap(args.batch_stride_A, args.batch_stride_B); - - return args; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - - cutlass::gemm::GemmCoord problem_size{}; - cutlass::gemm::GemmCoord grid_tiled_shape{}; - int swizzle_log_tile{0}; - - typename Mma::IteratorA::Params params_A{}; - typename Mma::IteratorB::Params params_B{}; - typename Epilogue::OutputTileIterator::Params params_D{}; - - typename EpilogueOutputOp::Params output_op{}; - - GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; - int batch_count {0}; - int gemm_k_size {0}; - - void * ptr_A{nullptr}; - void * ptr_B{nullptr}; - void * ptr_D{nullptr}; - - int64_t batch_stride_A {0}; - int64_t batch_stride_B {0}; - int64_t batch_stride_D {0}; - - int *semaphore{nullptr}; - - // - // Methods - // - Params() = default; - - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(args.lda), - params_B(args.ldb), - params_D(args.ldd), - output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_D(args.ptr_D), - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_D(args.batch_stride_D), - semaphore(static_cast(workspace)) { - } - - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { - - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_D = args.ptr_D; - - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_D = args.batch_stride_D; - - output_op = args.epilogue; - - semaphore = static_cast(workspace); - } - - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - -public: - - // - // Methods - // - - CUTLASS_DEVICE - TrmmUniversal() { } - - /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size) { - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { - - return Status::kErrorMisalignedOperand; - } - - return Status::kSuccess; - } - - static Status can_implement(Arguments const &args) { - return can_implement(args.problem_size); - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA *ptr_A = static_cast(params.ptr_A); - ElementB *ptr_B = static_cast(params.ptr_B); - - // - // Fetch pointers based on mode. - // - if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } - - __syncthreads(); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_B{ - offset_k, - threadblock_tile_offset.n() * Mma::Shape::kN - }; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - - /****************************************************************************************************** - First two cases: (Left Side, Lower Fill) and (Right Side, Upper Fill) are transpose of each other - - (Left Side, Lower Fill): calculate bottom of the CTA tile, then find the k-iterations - needed to process all elements till that coordinate. - - (Right Side, Upper Fill): calculate right end of the CTA tile, then find the k-iterations - needed to process all elements till that coordinate. - - Last two cases: (Left Side, Upper Fill) and (Right Side, Lower Fill) are transpose of each other - - (Left Side, Upper Fill): calculate the top of the CTA tile, then find k-iterations - that can be skipped for all elements of this tile. - - (Right Side, Lower Fill): calculate the left start of the CTA tile, then find k-iterations - that can be skipped for all elements of this tile. - ********************************************************************************************************/ - - if (kSideMode == SideMode::kLeft && kFillMode == FillMode::kLower) { - - int k_iterations_till_diagonal = ((threadblock_tile_offset.m() + 1) * Mma::Shape::kM + Mma::Shape::kK - 1) / Mma::Shape::kK; - if (k_iterations_till_diagonal < gemm_k_iterations) { - gemm_k_iterations = k_iterations_till_diagonal; - } - - } else if (kSideMode == SideMode::kRight && kFillMode == FillMode::kUpper) { - - int k_iterations_till_diagonal = ((threadblock_tile_offset.n() + 1) * Mma::Shape::kN + Mma::Shape::kK - 1) / Mma::Shape::kK; - if (k_iterations_till_diagonal < gemm_k_iterations) { - gemm_k_iterations = k_iterations_till_diagonal; - } - - } else if (kSideMode == SideMode::kLeft && kFillMode == FillMode::kUpper) { - - int k_iterations_till_diagonal = ((threadblock_tile_offset.m()) * Mma::Shape::kM) / Mma::Shape::kK; - - if (k_iterations_till_diagonal != 0) { - tb_offset_A += cutlass::MatrixCoord({0, k_iterations_till_diagonal * Mma::Shape::kK}); - tb_offset_B += cutlass::MatrixCoord({k_iterations_till_diagonal * Mma::Shape::kK, 0}); - gemm_k_iterations -= k_iterations_till_diagonal; - } - - } else if (kSideMode == SideMode::kRight && kFillMode == FillMode::kLower) { - - int k_iterations_till_diagonal = ((threadblock_tile_offset.n()) * Mma::Shape::kN) / Mma::Shape::kK; - - if (k_iterations_till_diagonal != 0) { - tb_offset_A += cutlass::MatrixCoord({0, k_iterations_till_diagonal * Mma::Shape::kK}); - tb_offset_B += cutlass::MatrixCoord({k_iterations_till_diagonal * Mma::Shape::kK, 0}); - gemm_k_iterations -= k_iterations_till_diagonal; - } - - } - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, - ptr_B, - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B); - - // Compute threadblock-scoped matrix multiply-add - mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - //assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - ElementC *ptr_D = static_cast(params.ptr_D); - - // - // Fetch pointers based on mode. - // - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - if (params.mode == GemmUniversalMode::kGemm) { - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - } - else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kBatched) { - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; - } - - - // Tile iterator loading from source tensor (although irrelevant to this kernel as beta is zero). - typename Epilogue::OutputTileIterator iterator_C( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - - __threadfence(); - } - - - // Execute the epilogue operator to update the destination tensor. - epilogue( - output_op, - iterator_D, - accumulators, - iterator_C); - - // - // Release the semaphore - // - - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma.h deleted file mode 100644 index 018963b260979d771d86070cfc79c989a710a059..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma.h +++ /dev/null @@ -1,90 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for warp-level multiply-add operations -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/mma.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Concept: arch::OpMultiplyAdd or arch::Mma<> - typename Operator = arch::OpMultiplyAdd, - /// Used for partial specialization - typename Enable = bool -> -struct Mma; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Overloads specialized for existing architectures -// - -#include "cutlass/gemm/thread/mma_sm50.h" -#include "cutlass/gemm/thread/mma_sm60.h" -#include "cutlass/gemm/thread/mma_sm61.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm50.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm50.h deleted file mode 100644 index e05c56e3081ea2bb9ac72051c1a22f46394ff6ee..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm50.h +++ /dev/null @@ -1,540 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/thread/mma.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Gemplate that handles all packed matrix layouts -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: layout::MapFunc) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: layout::MapFunc) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: layout::MapFunc) - typename LayoutC_, - /// Operator used to compute GEMM - typename Operator_ -> -struct MmaGeneric { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = ElementA_; - - /// Layout of A matrix (concept: layout::MapFunc) - using LayoutA = LayoutA_; - - /// Data type of operand B - using ElementB = ElementB_; - - /// Layout of B matrix (concept: layout::MapFunc) - using LayoutB = LayoutB_; - - /// Element type of operand C - using ElementC = ElementC_; - - /// Layout of C matrix (concept: layout::MapFunc) - using LayoutC = LayoutC_; - - /// Underlying mathematical operator - using Operator = Operator_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Instruction - using MmaOp = arch::Mma< - gemm::GemmShape<1,1,1>, - 1, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - Operator>; - - static bool const kMultipleOf2 = ((Shape::kM % 2 == 0) && (Shape::kN % 2 == 0)); - - static bool const kAllFp32 = platform::is_same::value && - platform::is_same::value && - platform::is_same::value; - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - TensorRef a_ref( - reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); - - TensorRef b_ref( - reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); - - TensorRef d_ref( - reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); - - MmaOp mma_op; - - // Copy accumulators - D = C; - - // Compute matrix product - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK; ++k) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) - if constexpr (kMultipleOf2 && kAllFp32) { - //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; n+=2) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; m+=2) { - - int m_serpentine = (n % 4) ? (Shape::kM - 2 - m) : m; - - //top-left element in 2x2 tile - { - MatrixCoord mn(m_serpentine, n); - MatrixCoord mk(m_serpentine, k); - MatrixCoord kn(k, n); - Array d; - Array a; - Array b; - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); - mma_op(d, a, b, d); - d_ref.at(mn) = d[0]; - } - - //bottom-left element in 2x2 tile - { - MatrixCoord mn(m_serpentine+1, n); - MatrixCoord mk(m_serpentine+1, k); - MatrixCoord kn(k, n); - Array d; - Array a; - Array b; - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); - mma_op(d, a, b, d); - d_ref.at(mn) = d[0]; - } - - //bottom-right element in 2x2 tile - { - MatrixCoord mn(m_serpentine+1, n+1); - MatrixCoord mk(m_serpentine+1, k); - MatrixCoord kn(k, n+1); - Array d; - Array a; - Array b; - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); - mma_op(d, a, b, d); - d_ref.at(mn) = d[0]; - } - - //top-right element in 2x2 tile - { - MatrixCoord mn(m_serpentine, n+1); - MatrixCoord mk(m_serpentine, k); - MatrixCoord kn(k, n+1); - Array d; - Array a; - Array b; - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); - mma_op(d, a, b, d); - d_ref.at(mn) = d[0]; - } - } - } - } else - #endif - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { - - int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; - - MatrixCoord mn(m_serpentine, n); - MatrixCoord mk(m_serpentine, k); - MatrixCoord kn(k, n); - - Array d; - Array a; - Array b; - - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); - - mma_op(d, a, b, d); - - d_ref.at(mn) = d[0]; - } - } - } - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Matrix multiply-add operation - assumes operand B is not changing -struct MmaComplexF32_Column { - - using Shape = gemm::GemmShape<1, 1, 1>; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array, 1> const &a, - Array, 1> const &b, - Array, 1> const &c - ) { - - d[0].real() = a[0].real() * b[0].real() + c[0].real(); - d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); - d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); - d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); - } -}; - -/// Matrix multiply-add operation - assumes operand A is not changing -struct MmaComplexF32_Corner { - - using Shape = gemm::GemmShape<1, 1, 1>; - using ElementC = complex; - - CUTLASS_HOST_DEVICE - void operator()( - Array, 1> &d, - Array, 1> const &a, - Array, 1> const &b, - Array, 1> const &c - ) { - - d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); - d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); - d[0].real() = a[0].real() * b[0].real() + c[0].real(); - d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); - } -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Gemplate that handles all packed matrix layouts -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of A matrix (concept: layout::MapFunc) - typename LayoutA_, - /// Layout of B matrix (concept: layout::MapFunc) - typename LayoutB_, - /// Layout of C matrix (concept: layout::MapFunc) - typename LayoutC_ -> -struct MmaGeneric< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - arch::OpMultiplyAdd> { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = complex; - - /// Layout of A matrix (concept: layout::MapFunc) - using LayoutA = LayoutA_; - - /// Data type of operand B - using ElementB = complex; - - /// Layout of B matrix (concept: layout::MapFunc) - using LayoutB = LayoutB_; - - /// Element type of operand C - using ElementC = complex; - - /// Layout of C matrix (concept: layout::MapFunc) - using LayoutC = LayoutC_; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Instruction - using MmaOp = arch::Mma< - gemm::GemmShape<1,1,1>, - 1, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - Operator>; - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - TensorRef a_ref( - reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); - - TensorRef b_ref( - reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); - - TensorRef d_ref( - reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); - - detail::MmaComplexF32_Column mma_column; - detail::MmaComplexF32_Corner mma_corner; - - // Copy accumulators - D = C; - - // Compute matrix product - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK; ++k) { - - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { - - int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; - - MatrixCoord mn(m_serpentine, n); - MatrixCoord mk(m_serpentine, k); - MatrixCoord kn(k, n); - - Array d; - Array a; - Array b; - - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); - - if ((m == 0 && n) || m == Shape::kM - 1) { - mma_corner(d, a, b, d); - } - else { - mma_column(d, a, b, d); - } - - d_ref.at(mn) = d[0]; - } - } - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Gemplate that handles conventional layouts for FFMA and DFMA GEMM -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: layout::MapFunc) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: layout::MapFunc) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: layout::MapFunc) - typename LayoutC_ -> -struct Mma< - Shape_, - ElementA_, - LayoutA_, - ElementB_, - LayoutB_, - ElementC_, - LayoutC_, - arch::OpMultiplyAdd, - bool> { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = ElementA_; - - /// Layout of A matrix (concept: layout::MapFunc) - using LayoutA = LayoutA_; - - /// Data type of operand B - using ElementB = ElementB_; - - /// Layout of B matrix (concept: layout::MapFunc) - using LayoutB = LayoutB_; - - /// Element type of operand C - using ElementC = ElementC_; - - /// Layout of C matrix (concept: layout::MapFunc) - using LayoutC = LayoutC_; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename MmaGeneric< - Shape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator>::MmaOp; - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - MmaGeneric< - Shape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator> mma; - - mma(D, A, B, C); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm60.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm60.h deleted file mode 100644 index 64c8e033af3f60d3c85f642ade9ad2b43797146c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm60.h +++ /dev/null @@ -1,1161 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/thread/mma.h" -#include "cutlass/functional.h" -#include "cutlass/reduction/thread/reduce.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Structure to compute the matrix product for HFMA -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape, - - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - - /// Type of GEMM inner vs outer product - bool -> -struct Mma_HFMA2; - - -///////////////////////////// -// Specialization for NNN // -///////////////////////////// - -template -struct Mma_HFMA2 < - Shape_, - layout::ColumnMajor, - layout::ColumnMajor, - layout::ColumnMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the M dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x1x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<2,1,1>, - 1, - half_t, - layout::ColumnMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::ColumnMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - Array tmp { ptr_D[n*Shape::kM/2 + m] }; - - mma( - tmp, - ptr_A[k*Shape::kM/2 + m], - ptr_B[n*Shape::kK + k], - tmp); - - ptr_D[n*Shape::kM/2 + m] = tmp; - } - } - } - } -}; - -///////////////////////////// -// Specialization for NNT // -///////////////////////////// - -template -struct Mma_HFMA2< - Shape_, - layout::ColumnMajor, - layout::ColumnMajor, - layout::RowMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x2x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<1,2,1>, - 1, - half_t, - layout::ColumnMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - Array tmp { ptr_D[m*Shape::kN/2 + n] }; - - Array tmp_B; - tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); - tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k); - - mma( - tmp, - ptr_A[k*Shape::kM + m], - tmp_B, - tmp); - - ptr_D[m*Shape::kN/2 + n] = tmp; - } - } - } - } -}; - - -///////////////////////////// -// Specialization for NTN // -///////////////////////////// - -template -struct Mma_HFMA2 < - Shape_, - layout::ColumnMajor, - layout::RowMajor, - layout::ColumnMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - using Mma = arch::Mma< - gemm::GemmShape<2,1,1>, - 1, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) { - - Array tmp { ptr_D[m + n * Shape::kM/2] }; - - mma( - tmp, - ptr_A[m + k * Shape::kM/2], - ptr_B[k * Shape::kN + n], - tmp); - - ptr_D[m + n * Shape::kM/2] = tmp; - } - } - } - } -}; - -///////////////////////////// -// Specialization for NTT // -///////////////////////////// - -template -struct Mma_HFMA2< - Shape_, - layout::ColumnMajor, - layout::RowMajor, - layout::RowMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x2x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<1,2,1>, - 1, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - Array tmp { ptr_D[m*Shape::kN/2 + n] }; - - mma( - tmp, - ptr_A[k*Shape::kM + m], - ptr_B[k*Shape::kN/2 + n], - tmp); - - ptr_D[m*Shape::kN/2 + n] = tmp; - } - } - } - } -}; - - -///////////////////////////// -// Specialization for TNN // -///////////////////////////// - -template -struct Mma_HFMA2 < - Shape_, - layout::RowMajor, - layout::ColumnMajor, - layout::ColumnMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the M dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x1x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<2,1,1>, - 1, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::ColumnMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - Array tmp { ptr_D[n*Shape::kM/2 + m] }; - - Array tmp_A; - tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); - tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k); - - mma( - tmp, - tmp_A, - ptr_B[n*Shape::kK + k], - tmp); - - ptr_D[n*Shape::kM/2 + m] = tmp; - } - } - } - } -}; - -///////////////////////////// -// Specialization for TNT // -///////////////////////////// - -template -struct Mma_HFMA2 < - Shape_, - layout::RowMajor, - layout::ColumnMajor, - layout::RowMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x2x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<1,2,1>, - 1, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - half_t, - layout::RowMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - Array tmp { ptr_D[m*Shape::kN/2 + n] }; - - Array tmp_B; - tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); - tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k); - - mma( - tmp, - ptr_A[m*Shape::kK + k], - tmp_B, - tmp); - - ptr_D[m*Shape::kN/2 + n] = tmp; - } - } - } - } -}; - -///////////////////////////// -// Specialization for TTN // -///////////////////////////// - -template -struct Mma_HFMA2 < - Shape_, - layout::RowMajor, - layout::RowMajor, - layout::ColumnMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the M dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x2x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<2,1,1>, - 1, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - half_t, - layout::ColumnMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - Array tmp { ptr_D[n*Shape::kM/2 + m] }; - - Array tmp_A; - tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); - tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k); - - mma( - tmp, - tmp_A, - ptr_B[k*Shape::kN + n], - tmp); - - ptr_D[n*Shape::kM/2 + m] = tmp; - } - } - } - } -}; - - -///////////////////////////// -// Specialization for TTT // -///////////////////////////// - -template -struct Mma_HFMA2< - Shape_, - layout::RowMajor, - layout::RowMajor, - layout::RowMajor, - true - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x2x1 HFMA2 sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<1,2,1>, - 1, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - half_t, - layout::RowMajor, - arch::OpMultiplyAdd>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Mma mma; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - - Array tmp { ptr_D[m*Shape::kN/2 + n] }; - - mma( - tmp, - ptr_A[m*Shape::kK + k], - ptr_B[k*Shape::kN/2 + n], - tmp); - - ptr_D[m*Shape::kN/2 + n] = tmp; - } - } - } - } -}; - -///////////////////////////////////////////////////////////////////// -// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T // -///////////////////////////////////////////////////////////////////// - -template -struct Mma_HFMA2< - Shape_, - LayoutA, - LayoutB, - layout::RowMajor, - false - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kK % 2), - "Mma_HFMA2 requires the K dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x1x2 HFMA2 sequence for bulk of computation - using GemmShape = gemm::GemmShape<1,1,2>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - // Inner product is calculated using MACs, followed by final reduction - multiply_add> mac; - cutlass::reduction::thread::Reduce< plus, Array > reduce; - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / GemmShape::kN; n++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / GemmShape::kM; m++){ - - Array tmp_C; - tmp_C.clear(); - Array *ptr_tmp_C = reinterpret_cast *>(&tmp_C); - ptr_tmp_C[0] = ptr_D[n*Shape::kM + m]; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / GemmShape::kK; k++){ - tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C); - } - - Array res; - Array *ptr_res = &res; - res = reduce(tmp_C); - - ptr_D[m*Shape::kN + n] = ptr_res[0]; - } - } - } -}; - -///////////////////////////////////////////////////////////////////// -// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N // -///////////////////////////////////////////////////////////////////// - -template -struct Mma_HFMA2< - Shape_, - LayoutA, - LayoutB, - layout::ColumnMajor, - false - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - static_assert( - !(Shape::kK % 2), - "Mma_HFMA2 requires the K dimension to be divisible by 2." - ); - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - /// Initialize output with input - D = C; - - /// Use 1x1x2 HFMA2 sequence for bulk of computation - using GemmShape= gemm::GemmShape<1,1,2>; - - Array *ptr_D = reinterpret_cast *>(&D); - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - // Inner product is calculated using MACs, followed by final reduction - multiply_add> mac; - cutlass::reduction::thread::Reduce< plus, Array > reduce; - - CUTLASS_PRAGMA_UNROLL - for(auto n=0; n < Shape::kN / GemmShape::kN; n++){ - - CUTLASS_PRAGMA_UNROLL - for(auto m=0; m < Shape::kM / GemmShape::kM; m++){ - - Array tmp_C; - tmp_C.clear(); - Array *ptr_tmp_C = reinterpret_cast *>(&tmp_C); - ptr_tmp_C[0] = ptr_D[n*Shape::kM + m]; - - CUTLASS_PRAGMA_UNROLL - for(auto k=0; k < Shape::kK / GemmShape::kK; k++){ - - tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C); - - } - - Array res; - Array *ptr_res = &res; - res = reduce(tmp_C); - - ptr_D[n*Shape::kM + m] = ptr_res[0]; - } - } - } -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, typename LayoutA, typename LayoutB, typename LayoutC -> -struct Mma< - Shape_, - half_t, - LayoutA, - half_t, - LayoutB, - half_t, - LayoutC, - arch::OpMultiplyAdd - > { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = half_t; - - /// Data type of operand B - using ElementB = half_t; - - /// Element type of operand C - using ElementC = half_t; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - static bool const a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value; - static bool const b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value; - static bool const c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value; - static bool const c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value; - - static bool const m_mod2 = !(Shape::kM % 2); - static bool const n_mod2 = !(Shape::kN % 2); - static bool const k_mod2 = !(Shape::kK % 2); - - // HFMA based MMA optimizations are of 2 types : - // 1. Inner product - // 2. Outer product - // It is chosen based on LayoutC (for outer product gemm) or - // Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms) - // If all fails, we choose the generic MMA - static bool const use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2); - static bool const use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2); - static bool const use_optimized = (use_outer_prod || use_inner_prod); - - using ArchMmaOperator = typename platform::conditional< use_optimized, - detail::Mma_HFMA2, - MmaGeneric - >::type; - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - ArchMmaOperator mma; - - mma(D, A, B, C); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - - /// Determines whether to enable thread::Gemm<> specializations compatible with SM50 - template < - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB> - struct EnableMma_Crow_SM60 { - - static bool const kIsConventionalLayout = - (platform::is_same::value || - platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value); - - static bool const value = kIsConventionalLayout; - }; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes matrix product when C is row-major -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - typename LayoutA_, - typename LayoutB_ -> -struct Mma< - Shape_, - half_t, - LayoutA_, - half_t, - LayoutB_, - half_t, - layout::RowMajor, - arch::OpMultiplyAdd, - typename platform::enable_if::value>::type>{ - - using Shape = Shape_; - using ElementA = half_t; - using LayoutA = LayoutA_; - using ElementB = half_t; - using LayoutB = LayoutB_; - using ElementC = half_t; - using LayoutC = layout::RowMajor; - using Operator = arch::OpMultiplyAdd; - - using TransposeMma = Mma< - GemmShapeTranspose, - half_t, - typename layout::LayoutTranspose::type, - half_t, - typename layout::LayoutTranspose::type, - half_t, - layout::ColumnMajor, - arch::OpMultiplyAdd, - bool>; - - using FragmentA = Array; - using FragmentB = Array; - using FragmentC = Array; - - using ArchMmaOperator = typename TransposeMma::ArchMmaOperator; - - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - TransposeMma mma; - - mma(D, B, A, C); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm61.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm61.h deleted file mode 100644 index f7127ed842133a147db2f1cdeaa700ce3d69dc90..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/thread/mma_sm61.h +++ /dev/null @@ -1,284 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/thread/mma.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Gemplate that handles conventional layouts for IDP4A -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_ -> -struct Mma< - Shape_, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int32_t, - LayoutC_, - arch::OpMultiplyAdd, - bool> { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = int8_t; - - /// Layout of A matrix (concept: layout::MapFunc) - using LayoutA = layout::RowMajor; - - /// Data type of operand B - using ElementB = int8_t; - - /// Layout of B matrix (concept: layout::MapFunc) - using LayoutB = layout::ColumnMajor; - - /// Element type of operand C - using ElementC = int32_t; - - /// Layout of C matrix (concept: layout::MapFunc) - using LayoutC = LayoutC_; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying matrix multiply operator (concept: arch::Mma) - // Use 1x1x4 IDP4A sequence for bulk of computation - using ArchMmaOperator = arch::Mma< - gemm::GemmShape<1,1,4>, - 1, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - arch::OpMultiplyAdd>; - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - TensorRef d( - reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); - - // Copy accumulators - D = C; - - /// Use 1x1x4 IDP4A sequence for bulk of computation - ArchMmaOperator mma; - - // Compute matrix product - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { - MatrixCoord mn(m, n); - - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - Array tmp = reinterpret_cast &>(d.at(mn)); - - mma( - tmp, - ptr_A[m * Shape::kK / ArchMmaOperator::Shape::kK + k], - ptr_B[n * Shape::kK / ArchMmaOperator::Shape::kK + k], - tmp); - - d.at(mn) = reinterpret_cast(tmp); - } - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Gemplate that handles conventional layouts for IDP4A -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_ -> -struct Mma< - Shape_, - int8_t, - layout::ColumnMajor, - int8_t, - layout::RowMajor, - int32_t, - LayoutC_, - arch::OpMultiplyAdd, - int8_t> { - - /// Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - /// Data type of operand A - using ElementA = int8_t; - - /// Layout of A matrix (concept: layout::MapFunc) - using LayoutA = layout::ColumnMajor; - - /// Data type of operand B - using ElementB = int8_t; - - /// Layout of B matrix (concept: layout::MapFunc) - using LayoutB = layout::RowMajor; - - /// Element type of operand C - using ElementC = int32_t; - - /// Layout of C matrix (concept: layout::MapFunc) - using LayoutC = LayoutC_; - - /// Underlying mathematical operator - using Operator = arch::OpMultiplyAdd; - - /// A operand storage - using FragmentA = Array; - - /// B operand storage - using FragmentB = Array; - - /// C operand storage - using FragmentC = Array; - - /// Underlying matrix multiply operator (concept: arch::Mma) - /// Use 1x1x4 IDP4A sequence for bulk of computation - using ArchMmaOperator = arch::Mma< - gemm::GemmShape<1,1,4>, - 1, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - arch::OpMultiplyAdd>; - - // - // Methods - // - - /// Computes a matrix product D = A * B + C - CUTLASS_HOST_DEVICE - void operator()( - FragmentC & D, - FragmentA const & A, - FragmentB const & B, - FragmentC const & C) { - - TensorRef d( - reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); - - // Copy accumulators - D = C; - - /// Underlying matrix multiply operator - ArchMmaOperator mma; - - Array const *ptr_A = reinterpret_cast const *>(&A); - Array const *ptr_B = reinterpret_cast const *>(&B); - - // Compute matrix product - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { - MatrixCoord mn(m, n); - - Array tmp = reinterpret_cast &>(d.at(mn)); - - mma( - tmp, - ptr_A[m + k * Shape::kM], - ptr_B[n + k * Shape::kN], - tmp); - - d.at(mn) = reinterpret_cast(tmp); - } - } - } - } -}; - -} // namespace thread -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h deleted file mode 100644 index 0ae82f32a857315466af13ce485313d6bc67efe0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h +++ /dev/null @@ -1,734 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default template for a Blocked-Ell MMA. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/wmma.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -#endif //CUTLASS_ARCH_WMMA_ENABLED - -#include "cutlass/gemm/threadblock/ell_mma_pipelined.h" -#include "cutlass/gemm/threadblock/ell_mma_multistage.h" -#include "cutlass/transform/threadblock/ell_predicated_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false - > -struct DefaultEllMma; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass Simt) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultEllMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, - arch::OpClassSimt, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator - > -struct DefaultEllMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, - arch::OpClassTensorOp, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator - > -struct DefaultEllMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, - LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, - arch::OpMultiplyAddFastF16>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, float, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for column-major-interleaved output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Number of Interleaved K - int InterleavedK> -struct DefaultEllMma, OperatorClass, - ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, - Operator, true> { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, - layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, - true>; - - static_assert(kAlignmentA == 128 / sizeof_bits::value, - "Alignment must match thread data map's vector length"); - - static_assert(kAlignmentB ==128 / sizeof_bits::value, - "Alignment must match thread data map's vector length"); - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, ElementA, - LayoutA, 1, typename MmaCore::IteratorThreadMapA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, ElementB, - LayoutB, 0, typename MmaCore::IteratorThreadMapB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::ColumnMajorInterleaved, - typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultEllMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, - Stages, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultEllMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for column-major-interleaved output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Number of Interleaved K - int InterleavedK> -struct DefaultEllMma, OperatorClass, - ArchTag, ThreadblockShape, WarpShape, InstructionShape, - Stages, Operator, true> { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, - layout::ColumnMajorInterleaved, OperatorClass, Stages, - Operator, true>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for SIMT IDP4A Kernels -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Operation performed by GEMM - typename Operator, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape> -struct DefaultEllMma, 2, - Operator, false> { - using InstructionShape = GemmShape<1, 1, 4>; - using ElementA = int8_t; - using ElementB = int8_t; - using OperatorClass = arch::OpClassSimt; - - static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value; - static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, - OperatorClass, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -/// Specialization for Wmma TensorOp operator with 2 staged pipeline -template < - ///< Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultEllMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, - arch::OpClassWmmaTensorOp, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - LayoutC, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for Wmma TensorOp operator with 1 staged pipeline -template < - ///< Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultEllMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, - arch::OpClassWmmaTensorOp, 1, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::EllPredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped singlestage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - LayoutC, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// -#endif //CUTLASS_ARCH_WMMA_ENABLED - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h deleted file mode 100644 index 214f451c152451d0b78f70bc191cc5ead625286a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h +++ /dev/null @@ -1,151 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level batched GEMV assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting SIMT instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/layout/matrix.h" - -#include "cutlass/platform/platform.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/thread/mma.h" - -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/pitch_linear_thread_map.h" - -#include "cutlass/gemm/threadblock/gemv.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass { -namespace gemm { -namespace threadblock { - -/// Template defininng default vector-matrix multiply operators inferred from threadblock tile size, -/// global memory data layout. -template < - typename Shape_, /// Shape of the threadblock vector-matrix multiply operator - typename ThreadShape_, /// Shape of per-thread vector-matrix multiply operator - typename ElementA_, /// Element data type of A operand - typename LayoutA_, /// Layout of operand A - typename ElementB_, /// Element data type of B operand - typename LayoutB_, /// Layout of operand B - typename ElementC_, /// Data type of accumulator - typename LayoutC_ /// Layout of accumulator -> -struct DefaultGemvCore { - - using Shape = Shape_; - using ThreadShape = ThreadShape_; - - using LayoutA = LayoutA_; - using LayoutB = LayoutB_; - using LayoutC = LayoutC_; - - using ElementA = ElementA_; - using ElementB = ElementB_; - using ElementC = ElementC_; - - static int const kThreadsPerN = Shape::kN / ThreadShape::kN; - - using IteratorPolicyA = typename platform::conditional< - platform::is_same::value, - cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< - layout::PitchLinearShape, 1, ThreadShape::kK>, - cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< - layout::PitchLinearShape, 1, ThreadShape::kM>>::type; - - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, IteratorPolicyA>; - - using IteratorPolicyB = typename platform::conditional< - platform::is_same::value, - cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< - layout::PitchLinearShape, kThreadsPerN, ThreadShape::kN>, - cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< - layout::PitchLinearShape, kThreadsPerN, ThreadShape::kK>>::type; - - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, IteratorPolicyB>; - - using IteratorPolicyC = typename platform::conditional< - platform::is_same::value, - cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< - layout::PitchLinearShape, kThreadsPerN, ThreadShape::kN>, - cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< - layout::PitchLinearShape, kThreadsPerN, ThreadShape::kM>>::type; - - using IteratorC = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementC, LayoutC, 0, IteratorPolicyC>; - - using MmaSimtOp = typename cutlass::gemm::thread::Mma< - cutlass::gemm::GemmShape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC>; - - using Operator = MmaSimtOp; - - // Assertions for correctness - static_assert((Shape::kM == 1), "M=1 is required for GEMV"); - - static_assert((ThreadShape::kM == 1), "M=1 is required for GEMV"); - - static_assert(Shape::kK % ThreadShape::kK == 0, "Shape::K must be a multiple of ThreadShape::K"); - - static_assert(((ThreadShape::kK == 1) || - (ThreadShape::kK == 2) || - (ThreadShape::kK == 4) || - (ThreadShape::kK == 8) || - (ThreadShape::kK == 16) || - (ThreadShape::kK == 32) - ), - "ThreadShape::K must be a 1, 2, 4, 8, 16 or 32"); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma.h deleted file mode 100644 index ee573dbe8dac3576b8647a8302d7a5fb7b677edb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma.h +++ /dev/null @@ -1,823 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/wmma.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/permute.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -#endif //CUTLASS_ARCH_WMMA_ENABLED - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Gather operand A by using an index array - bool GatherA = false, - /// Gather operand B by using an index array - bool GatherB = false, - /// Permute operand A - typename PermuteALayout = layout::NoPermute, - /// Permute operand B - typename PermuteBLayout = layout::NoPermute - > -struct DefaultMma; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass Simt) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operand - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Permute operand A - typename PermuteALayout, - /// Permute operand B - typename PermuteBLayout - > -struct DefaultMma { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "simt epilogue must be row major"); - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, - arch::OpClassSimt, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, - GatherA, PermuteALayout>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, - GatherB, PermuteBLayout>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - LayoutC, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Permute operand A - typename PermuteALayout, - /// Permute operand B - typename PermuteBLayout - > -struct DefaultMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, - arch::OpClassTensorOp, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, - GatherA, PermuteALayout>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, - GatherB, PermuteBLayout>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Permute operand A - typename PermuteALayout, - /// Permute operand B - typename PermuteBLayout - > -struct DefaultMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, - LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, - arch::OpMultiplyAddFastF16>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, - GatherA, PermuteALayout>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, - GatherB, PermuteBLayout>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, float, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for column-major-interleaved output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Number of Interleaved K - int InterleavedK> -struct DefaultMma, OperatorClass, - ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, - Operator, true, SharedMemoryClearOption::kNone, false, false, - layout::NoPermute, layout::NoPermute> { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, - layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, - true>; - - static_assert(kAlignmentA == 128 / sizeof_bits::value, - "Alignment must match thread data map's vector length"); - - static_assert(kAlignmentB ==128 / sizeof_bits::value, - "Alignment must match thread data map's vector length"); - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, - LayoutA, 1, typename MmaCore::IteratorThreadMapA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementB, - LayoutB, 0, typename MmaCore::IteratorThreadMapB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::ColumnMajorInterleaved, - typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operand - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Permute operand A - typename PermuteALayout, - /// Permute operand B - typename PermuteBLayout - > -struct DefaultMma { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "simt epilogue must be row major"); - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt, - Stages, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, LayoutC, - typename MmaCore::MmaPolicy, Stages>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operand - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB, - /// Permute operand A - typename PermuteALayout, - /// Permute operand B - typename PermuteBLayout - > -struct DefaultMma { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "simt epilogue must be row major"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, LayoutC, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for column-major-interleaved output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Number of Interleaved K - int InterleavedK> -struct DefaultMma, OperatorClass, - ArchTag, ThreadblockShape, WarpShape, InstructionShape, - Stages, Operator, true, SharedMemoryClearOption::kNone, - false, false, layout::NoPermute, layout::NoPermute> { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, - layout::ColumnMajorInterleaved, OperatorClass, Stages, - Operator, true>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for SIMT IDP4A Kernels -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Operation performed by GEMM - typename Operator, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape> -struct DefaultMma, 2, - Operator, false, SharedMemoryClearOption::kNone, - false, false, layout::NoPermute, layout::NoPermute> { - using InstructionShape = GemmShape<1, 1, 4>; - using ElementA = int8_t; - using ElementB = int8_t; - using OperatorClass = arch::OpClassSimt; - - static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value; - static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, - OperatorClass, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - layout::RowMajor, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -/// Specialization for Wmma TensorOp operator with 2 staged pipeline -template < - ///< Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, - arch::OpClassWmmaTensorOp, 2, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - LayoutC, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for Wmma TensorOp operator with 1 staged pipeline -template < - ///< Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, LayoutC, - arch::OpClassWmmaTensorOp, 1, Operator>; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - // Define the threadblock-scoped singlestage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, - LayoutC, typename MmaCore::MmaPolicy>; -}; - -//////////////////////////////////////////////////////////////////////////////// -#endif //CUTLASS_ARCH_WMMA_ENABLED - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h deleted file mode 100644 index 16860880e8d84b95b6134a149730ed3d6a21c2f5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h +++ /dev/null @@ -1,116 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/warp/mma.h" -#include "cutlass/gemm/threadblock/mma_pipelined.h" -#include "cutlass/gemm/threadblock/mma_singlestage.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/arch/mma.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template defininng default matrix multiply operators inferred from threadblock tile size, -/// global memory data layout, and target math instruction. -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Number of stages - int Stages = 2, - /// Operation performed by MMA - typename Operator = typename platform::conditional< - (platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global, - /// per-element transformation for elements of A - ComplexTransform TransformA = ComplexTransform::kNone, - /// per-element transformation for elements of B - ComplexTransform TransformB = ComplexTransform::kNone, - bool IsComplex = false // (is_complex::value || is_complex::value) -> -struct DefaultMmaCore; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h deleted file mode 100644 index 9c9f3e6f142d6c04768c1b10c904c85de6ce7cd0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h +++ /dev/null @@ -1,1723 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting simt instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - - -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h" - -#include "cutlass/gemm/warp/mma_simt_policy.h" -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -namespace detail { - -// convert a WarpShape which is the whole tile of elements into warp num threads. -// The goal is for each thread's tile of elements to be as square as possible -// for performance (4x4 will be faster than 2x8). -template -constexpr int simt_get_warp_threads_m() { - return (WarpShape::kM > WarpShape::kN) ? 8 : 4; -} - -/// Computes padding in shared memory to perform efficient transpose without bank conflicts. -constexpr int simt_transpose_padding(int threads, int crosswise, int size_in_bits) { - return (size_in_bits >= 32 ? - threads / crosswise / (size_in_bits / 32) : - threads / crosswise * (32 / size_in_bits) - ); -} - -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::ColumnMajor, ElementB_, layout::RowMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::RowMajor, ElementB_, layout::ColumnMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - SmemThreadMapA // was IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - SmemThreadMapB // was IteratorThreadMapA - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, // skew for A matrix to avoid SMEM bank conflicts - MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::RowMajor, ElementB_, layout::RowMajor, ElementC_, - LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - SmemThreadMapA - >; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - static_assert(!(kPaddingM % LaneM), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, // skew for A matrix to avoid SMEM bank conflicts - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::ColumnMajor, ElementB_, layout::ColumnMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - SmemThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - static_assert(!(kPaddingN % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2RowMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2ColumnMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2RowMajor, ElementC_, - LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2ColumnMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: simt class, for dp4a -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, int8_t, - layout::ColumnMajor, int8_t, layout::RowMajor, ElementC_, - LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 4>; - using ElementA = int8_t; - using LayoutA = layout::ColumnMajor; - using ElementB = int8_t; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorInterleaved<4>; - using SmemLayoutB = layout::RowMajorInterleaved<4>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(4, ThreadTileM); - static const int LaneN = cutlass::const_min(4, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 4>; - - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::ColumnMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - PartitionsK /// Number of partitions along K dimension - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization: -// -/// -/// A: Row-major -/// B: Column-major -/// Operator: simt class, for dp4a -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, int8_t, - layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, - LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 4>; - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorInterleaved<4>; - using SmemLayoutB = layout::RowMajorInterleaved<4>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - SmemThreadMapA - >; - - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - SmemThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(4, ThreadTileM); - static const int LaneN = cutlass::const_min(4, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 4>; - - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::ColumnMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - PartitionsK /// Number of partitions along K dimension - >; - - static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, - MatrixShape<0, kPaddingN>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization: -// -/// -/// A: Row-major -/// B: Row-major -/// Operator: simt class, for dp4a -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, int8_t, - layout::RowMajor, int8_t, layout::RowMajor, ElementC_, - LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 4>; - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using ElementB = int8_t; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorInterleaved<4>; - using SmemLayoutB = layout::RowMajorInterleaved<4>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - SmemThreadMapA - >; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(4, ThreadTileM); - static const int LaneN = cutlass::const_min(4, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 4>; - - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::ColumnMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - PartitionsK /// Number of partitions along K dimension - >; - - static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization: -// -/// -/// A: Column-major -/// B: Column-major -/// Operator: simt class, for dp4a -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, int8_t, - layout::ColumnMajor, int8_t, layout::ColumnMajor, ElementC_, - LayoutC_, arch::OpClassSimt, 2, Operator_ - > { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 4>; - using ElementA = int8_t; - using LayoutA = layout::ColumnMajor; - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorInterleaved<4>; - using SmemLayoutB = layout::RowMajorInterleaved<4>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 4> - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - SmemThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(4, ThreadTileM); - static const int LaneN = cutlass::const_min(4, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 4>; - - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::ColumnMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - PartitionsK /// Number of partitions along K dimension - >; - - static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, kPaddingN>, - WarpCount::kK - >; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h deleted file mode 100644 index fafc45c029b0bf7198231f1a7a6e2baddb8c122e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h +++ /dev/null @@ -1,682 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - - -#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h" - -#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::ColumnMajor, ElementB_, layout::RowMajor, - ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - // - // Shared memory layouts - // - - using SmemLayoutA = - layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< - sizeof_bits::value>; - - // Shared memory layout - using SmemLayoutB = - layout::RowMajorVoltaTensorOpMultiplicandBCongruous< - sizeof_bits::value>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - cutlass::gemm::GemmShape<16, 16, 4>, - 32, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - cutlass::layout::RowMajor, - cutlass::arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::RowMajor, ElementB_, layout::ColumnMajor, - ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 8>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 0, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 8>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 1, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - cutlass::gemm::GemmShape<16, 16, 4>, - 32, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - cutlass::layout::RowMajor, - cutlass::arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::RowMajor, ElementB_, layout::RowMajor, ElementC_, - LayoutC_, arch::OpClassTensorOp, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::RowMajorVoltaTensorOpMultiplicandBCongruous< - sizeof_bits::value>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 8>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 0, - IteratorThreadMapA - >; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - cutlass::gemm::GemmShape<16, 16, 4>, - 32, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - cutlass::layout::RowMajor, - cutlass::arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::ColumnMajor, ElementB_, layout::ColumnMajor, - ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< - sizeof_bits::value>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<4, 8>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 1, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - cutlass::gemm::GemmShape<16, 16, 4>, - 32, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - cutlass::layout::RowMajor, - cutlass::arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h deleted file mode 100644 index 39422ec8e20838861c19f5510aa60e6414972632..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h +++ /dev/null @@ -1,1315 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" - -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - platform::min(Shape::kM / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), - Shape::kM); - using SmemLayoutA = - layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_A>; - - // Shared memory layout - static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), - Shape::kN); - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_B>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by MMA - typename Operator_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 0, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 1, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by MMA - typename Operator_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), - Shape::kN); - - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_B>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 0, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by MMA - typename Operator_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - platform::min(Shape::kM / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), - Shape::kM); - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_A>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Below is for arch::OpMultiplyAddFastF16 - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = float; - using LayoutA = layout::ColumnMajor; - using ElementB = float; - using LayoutB = layout::RowMajor; - using ElementC = float; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 256; - - /// Default Operator - using Operator = arch::OpMultiplyAdd; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, int(128 / sizeof(half_t))>; - - // Shared memory layout - using SmemLayoutB = - layout::RowMajorTensorOpMultiplicandCongruous::value, - int(128 / sizeof(half_t))>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - half_t, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - half_t, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = float; - using LayoutA = layout::RowMajor; - using ElementB = float; - using LayoutB = layout::ColumnMajor; - using ElementC = float; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 256; - - /// Default Operator - using Operator = arch::OpMultiplyAdd; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = - layout::RowMajorTensorOpMultiplicandCrosswise::value, - Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - half_t, - SmemLayoutA, - 0, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - half_t, - SmemLayoutB, - 1, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = float; - using LayoutA = layout::RowMajor; - using ElementB = float; - using LayoutB = layout::RowMajor; - using ElementC = float; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 256; - - /// Default Operator - using Operator = arch::OpMultiplyAdd; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, int(128 / sizeof(half_t))>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - half_t, - SmemLayoutA, - 0, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - half_t, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = float; - using LayoutA = layout::ColumnMajor; - using ElementB = float; - using LayoutB = layout::ColumnMajor; - using ElementC = float; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 256; - - /// Default Operator - using Operator = arch::OpMultiplyAdd; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, int(128 / sizeof(half_t))>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, half_t, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, half_t, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, MatrixShape<0, 0>, - WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major-interleave -/// B: row-major-interleave -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -/// -/// Column/RowMajorInterleved(m, n) is mapped to Column/RowMajor(m -/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators -/// can be reused. The shared store iterator is the same as the crosswise shared -/// store iterator. So, the only thing we need to do is to swap the coordinates -/// (contiguous <=> strided) used by the global iterator and the shared store -/// iterator. -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by MMA - typename Operator_, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor, - /// Number of interleaved k - int InterleavedK> -struct DefaultMmaCore, ElementB_, - layout::RowMajorInterleaved, ElementC_, - LayoutC_, arch::OpClassTensorOp, 2, Operator_, - AccumulatorsInRowMajor> { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajorInterleaved; - using ElementB = ElementB_; - using LayoutB = layout::RowMajorInterleaved; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassTensorOp; - static int const kInterleavedK = InterleavedK; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kElementsPerAccess = - kAccessSizeInBits / sizeof_bits::value; - - static int const kWarpThreadArrangementContiguous = - kInterleavedK / kElementsPerAccess; - - static int const kWarpThreadArrangementStrided = - kWarpSize / kWarpThreadArrangementContiguous; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kInterleavedK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kInterleavedK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMap< - IteratorThreadMapA, - layout::PitchLinearShape>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - SmemThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapB = transform::TransposePitchLinearThreadMap< - IteratorThreadMapB, - layout::PitchLinearShape>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - SmemThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h deleted file mode 100644 index b5e14c6ad20e063078b838a6ed55bc04fde0d5c4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h +++ /dev/null @@ -1,2951 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming - expectations about data layout of the global memory fragments, data types, - and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp - instructions. - - SM80 Multi stage kernel expects stage number to be larger or equal to 3 - to use asynchronous copy. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/gemm/warp/mma_simt_policy.h" -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -#include "cutlass/gemm/threadblock/default_mma_core.h" -#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" -#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -#include "cutlass/gemm/threadblock/mma_multistage.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for double-precision -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::ColumnMajor; - using ElementB = double; - using LayoutB = layout::ColumnMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 64; - - /// Default Operator - using Operator = Operator_; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; - - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -/// Partial specialization for double-precision -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::ColumnMajor; - using ElementB = double; - using LayoutB = layout::RowMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 64; - - /// Default Operator - using Operator = Operator_; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; - - // Shared memory layout - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for double-precision -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::RowMajor; - using ElementB = double; - using LayoutB = layout::ColumnMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 64; - - /// Default Operator - using Operator = Operator_; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; - - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Partial specialization for double-precision -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::RowMajor; - using ElementB = double; - using LayoutB = layout::RowMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 64; - - /// Default Operator - using Operator = Operator_; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; - - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; - - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for double-precision -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::AffineRank2ColumnMajor; - using ElementB = double; - using LayoutB = layout::AffineRank2ColumnMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -/// Partial specialization for double-precision -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::AffineRank2ColumnMajor; - using ElementB = double; - using LayoutB = layout::AffineRank2RowMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for double-precision -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::AffineRank2RowMajor; - using ElementB = double; - using LayoutB = layout::AffineRank2ColumnMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Partial specialization for double-precision -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = double; - using LayoutA = layout::AffineRank2RowMajor; - using ElementB = double; - using LayoutB = layout::AffineRank2RowMajor; - using ElementC = double; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for float-precision -/// -/// ElementA: complex -/// ElementB: complex -/// ElementC: complex -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Layout for A operand - typename LayoutA_, - /// Layout for B operand - typename LayoutB_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// per-element transformation for elements of A - ComplexTransform TransformA_, - /// per-element transformation for elements of B - ComplexTransform TransformB_ - > -struct DefaultMmaCore< - Shape_, WarpShape_, GemmShape<16, 8, 8>, - complex, LayoutA_, - complex, LayoutB_, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - Operator_, - false, - CacheOpA, - CacheOpB, - TransformA_, TransformB_, true> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<16, 8, 8>; - using ElementA = complex; - using LayoutA = LayoutA_; - using ElementB = complex; - using LayoutB = LayoutB_; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - static const ComplexTransform TransformA = TransformA_; - static const ComplexTransform TransformB = TransformB_; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - static_assert( - platform::is_same::value || - platform::is_same::value || - platform::is_same::value, - "The operator tag must indicate complex multiplication."); - - // - // Underlying template - // - - using MmaComplexCore = DefaultMultistageMmaComplexCore< - Shape, WarpShape, InstructionShape, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - arch::OpClassTensorOp, - kStages, - TransformA, - TransformB, - Operator, - kCacheOpA, - kCacheOpB - >; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; - - // Shared memory layout - using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; - - /// ThreadMap of iterator B - using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; - - /// Policy used to define MmaPipelined - using MmaPolicy = typename MmaComplexCore::MmaPolicy; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for double-precision -/// -/// ElementA: complex -/// ElementB: complex -/// ElementC: complex -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout for A operand - typename LayoutA_, - /// Layout for B operand - typename LayoutB_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// per-element transformation for elements of A - ComplexTransform TransformA_, - /// per-element transformation for elements of B - ComplexTransform TransformB_ - > -struct DefaultMmaCore< - Shape_, WarpShape_, InstructionShape_, - complex, LayoutA_, - complex, LayoutB_, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - Operator_, - false, - CacheOpA, - CacheOpB, - TransformA_, TransformB_, true> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = complex; - using LayoutA = LayoutA_; - using ElementB = complex; - using LayoutB = LayoutB_; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - static const ComplexTransform TransformA = TransformA_; - static const ComplexTransform TransformB = TransformB_; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 64; - - /// Default Operator - using Operator = Operator_; - - static_assert( - platform::is_same::value || - platform::is_same::value, - "The operator tag must indicate complex multiplication."); - - // - // Underlying template - // - - using MmaComplexCore = DefaultMultistageMmaComplexCore< - Shape, WarpShape, InstructionShape, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - arch::OpClassTensorOp, - kStages, - TransformA, - TransformB, - Operator, - kCacheOpA, - kCacheOpB - >; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; - - // Shared memory layout - using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; - - /// ThreadMap of iterator B - using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; - - /// Policy used to define MmaPipelined - using MmaPolicy = typename MmaComplexCore::MmaPolicy; -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - platform::min(Shape::kM / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), - Shape::kM); - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_A>; - - // Shared memory layout - static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), - Shape::kN); - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_B>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - platform::min(Shape::kM / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), - Shape::kM); - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_A>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; - - // Shared memory layout - static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), - Shape::kN); - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_B>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major-interleaved -/// B: row-major-interleaved -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -/// -/// Column/RowMajorInterleved(m, n) is mapped to Column/RowMajor(m -/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators -/// can be reused. The shared store iterator is the same as the crosswise shared -/// store iterator. So, the only thing we need to do is to swap the coordinates -/// (contiguous <=> strided) used by the global iterator and the shared store -/// iterator. -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Number of interleaved K - int InterleavedK> -struct DefaultMmaCore, ElementB_, - layout::RowMajorInterleaved, ElementC_, - LayoutC_, arch::OpClassTensorOp, Stages, Operator_, - AccumulatorsInRowMajor, CacheOpA, CacheOpB> { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajorInterleaved; - using ElementB = ElementB_; - using LayoutB = layout::RowMajorInterleaved; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - static int const kInterleavedK = InterleavedK; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kElementsPerAccess = - kAccessSizeInBits / sizeof_bits::value; - - static int const kWarpThreadArrangementContiguous = - kInterleavedK / kElementsPerAccess; - - static int const kWarpThreadArrangementStrided = - kWarpSize / kWarpThreadArrangementContiguous; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kInterleavedK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kInterleavedK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMap< - IteratorThreadMapA, - layout::PitchLinearShape>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - SmemThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapB = transform::TransposePitchLinearThreadMap< - IteratorThreadMapB, - layout::PitchLinearShape>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - SmemThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - // Shared memory layout - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator B - using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - SmemThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static_assert(!((Shape::kK / 32) % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, Shape::kK / 32>, - WarpCount::kK>; -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - // Shared memory layout - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK>; -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - // Shared memory layout - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - SmemThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator B - using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - SmemThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static_assert(!((Shape::kK / 32) % LaneM) && !((Shape::kK / 32) % LaneN), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, - MatrixShape<0, Shape::kK / 32>, - WarpCount::kK>; -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - // Shared memory layout - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - SmemThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - - static_assert(!((Shape::kK / 32) % LaneM), - "Padding must be divisible by Lane"); - - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, - MatrixShape<0, 0>, - WarpCount::kK>; -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; - -}; - -/// Partial specialization for SIMT GEMMs using multistage pipeline. -/// -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by Simt - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::AffineRank2RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::AffineRank2RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Default Operator - using Operator = Operator_; - - using Base = DefaultMmaCore; - - // - // Shared memory layouts - // - - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - - /// Shared memory iterator to A operand - using SmemIteratorA = typename Base::SmemIteratorA; - - /// Policy of iterator B - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - - /// Shared memory iterator to B operand - using SmemIteratorB = typename Base::SmemIteratorB; - - // - // Warp-level matrix multiply operator - // - - /// Policy used to define MmaPipelined - using MmaPolicy = typename Base::MmaPolicy; - -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h deleted file mode 100644 index 4abf72352ba0d37441126be0ce2e0a6f12f0e0d6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h +++ /dev/null @@ -1,876 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming - expectations about data layout of the global memory fragments, data types, - and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting sparse - TensorOp instructions. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/gemm/warp/mma_simt_policy.h" -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" - -#include "cutlass/gemm/threadblock/default_mma_core.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -#include "cutlass/gemm/threadblock/mma_sparse_multistage.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Template defininng default matrix multiply operators inferred from threadblock tile size, -/// global memory data layout, and target math instruction. -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator = typename platform::conditional< - (platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false - /// Cache operation of operand A - , cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global -> -struct DefaultSparseMmaCore; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultSparseMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - static int const kSparse = 2; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - platform::min(Shape::kM / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), - Shape::kM); - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_A>; - - // Shared memory layout - static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), - Shape::kN); - - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_B>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Cache operation of operand E - static cutlass::arch::CacheOperation::Kind const kCacheOpE = - cutlass::arch::CacheOperation::Global; - - static int const kInterleavedE = MmaTensorOp::kInterleaved; - static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; - static int const kMaxID2 = MmaTensorOp::kMaxID2; - static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; - - using ElementE = typename MmaTensorOp::ElementE; - using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; - - // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. - using SmemLayoutE = typename MmaTensorOp::LayoutE; - - /// ThreadMap of iterator E - static int const kElementsPerAccessE = - kAccessSizeInBits / sizeof_bits::value; - - /// E is tiny. Not all warps are needed. - static int const kThreadsE = - (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value) > - kThreads) - ? kThreads - : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value)); - - using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreadsE, kElementsPerAccessE>; - - /// Shared memory iterator to E operand - using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< - MatrixShape, - ElementE, SmemLayoutE, 0, IteratorThreadMapE>; - - /// Policy used to define MmaPipelined - using MmaPolicy = - SparseMmaPolicy, MatrixShape<0, 0>, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultSparseMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - static int const kSparse = 2; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - // crosswise cannot be larger than 1024 bit. - static int const kCrosswiseB = - (Shape::kK > (1024 / sizeof_bits::value)) - ? (1024 / sizeof_bits::value) - : Shape::kK; - - static int const kWarpThreadArrangementContiguousB = - kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK / kSparse>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kCrosswiseB>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Cache operation of operand E - static cutlass::arch::CacheOperation::Kind const kCacheOpE = - cutlass::arch::CacheOperation::Global; - - static int const kInterleavedE = MmaTensorOp::kInterleaved; - static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; - static int const kMaxID2 = MmaTensorOp::kMaxID2; - static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; - - using ElementE = typename MmaTensorOp::ElementE; - using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; - - // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. - using SmemLayoutE = typename MmaTensorOp::LayoutE; - - /// ThreadMap of iterator E - static int const kElementsPerAccessE = - kAccessSizeInBits / sizeof_bits::value; - - /// E is tiny. Not all warps are needed. - static int const kThreadsE = - (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value) > - kThreads) - ? kThreads - : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value)); - - using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreadsE, kElementsPerAccessE>; - - - /// Shared memory iterator to E operand - using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< - MatrixShape, - ElementE, SmemLayoutE, 0, IteratorThreadMapE>; - - /// Policy used to define MmaPipelined - using MmaPolicy = - SparseMmaPolicy, MatrixShape<0, 0>, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultSparseMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - static int const kSparse = 2; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), - Shape::kM); - - static int const kWarpThreadArrangementContiguousA = - platform::min(Shape::kM / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - // Warp thread arrangement - // crosswise cannot be larger than 1024 bit. - static int const kCrosswiseB = - (Shape::kK > (1024 / sizeof_bits::value)) - ? (1024 / sizeof_bits::value) - : Shape::kK; - - static int const kWarpThreadArrangementContiguousB = - kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_A>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kCrosswiseB>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Cache operation of operand E - static cutlass::arch::CacheOperation::Kind const kCacheOpE = - cutlass::arch::CacheOperation::Global; - - static int const kInterleavedE = MmaTensorOp::kInterleaved; - static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; - static int const kMaxID2 = MmaTensorOp::kMaxID2; - static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; - - using ElementE = typename MmaTensorOp::ElementE; - using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; - - // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. - using SmemLayoutE = typename MmaTensorOp::LayoutE; - - /// ThreadMap of iterator E - static int const kElementsPerAccessE = - kAccessSizeInBits / sizeof_bits::value; - - /// E is tiny. Not all warps are needed. - static int const kThreadsE = - (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value) > - kThreads) - ? kThreads - : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value)); - - using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreadsE, kElementsPerAccessE>; - - /// Shared memory iterator to E operand - using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< - MatrixShape, - ElementE, SmemLayoutE, 0, IteratorThreadMapE>; - - /// Policy used to define MmaPipelined - using MmaPolicy = - SparseMmaPolicy, MatrixShape<0, 0>, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultSparseMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - static int const kSparse = 2; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), - Shape::kN); - - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK / kSparse>; - - // Shared memory layout - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise_B>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, WarpCount::kK>::Type; - - /// Cache operation of operand E - static cutlass::arch::CacheOperation::Kind const kCacheOpE = - cutlass::arch::CacheOperation::Global; - - static int const kInterleavedE = MmaTensorOp::kInterleaved; - static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; - static int const kMaxID2 = MmaTensorOp::kMaxID2; - static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; - - using ElementE = typename MmaTensorOp::ElementE; - using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; - - // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. - using SmemLayoutE = typename MmaTensorOp::LayoutE; - - /// ThreadMap of iterator E - static int const kElementsPerAccessE = - kAccessSizeInBits / sizeof_bits::value; - - /// E is tiny. Not all warps are needed. - static int const kThreadsE = - (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value) > - kThreads) - ? kThreads - : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / - (kAccessSizeInBits / sizeof_bits::value)); - - using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreadsE, kElementsPerAccessE>; - - /// Shared memory iterator to E operand - using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< - MatrixShape, - ElementE, SmemLayoutE, 0, IteratorThreadMapE>; - - /// Policy used to define MmaPipelined - using MmaPolicy = - SparseMmaPolicy, MatrixShape<0, 0>, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h deleted file mode 100644 index b260c91197f1a86c2521778527aa7d13791f7327..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h +++ /dev/null @@ -1,328 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting simt instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/warp/mma.h" -#include "cutlass/gemm/threadblock/mma_pipelined.h" -#include "cutlass/gemm/threadblock/mma_singlestage.h" -#include "cutlass/arch/cache_operation.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Size of a threadblock-scoped access - int kAccessSizeInBits = -1, // -1 denoting the default - /// Number of stages - int Stages = 2, - /// Operation performed by MMA - typename Operator = typename platform::conditional< - (platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global, - /// per-element transformation for elements of A - ComplexTransform TransformA = ComplexTransform::kNone, - /// per-element transformation for elements of B - ComplexTransform TransformB = ComplexTransform::kNone, - bool IsComplex = false // (is_complex::value || is_complex::value) -> -struct DefaultMmaCoreWithAccessSize; - -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Number of stages - int Stages, - /// Operation performed by MMA - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// per-element transformation for elements of A - ComplexTransform TransformA, - /// per-element transformation for elements of B - ComplexTransform TransformB, - bool IsComplex -> -struct DefaultMmaCoreWithAccessSize< - Shape, WarpShape, InstructionShape, - ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - OperatorClass, -1, Stages, Operator, AccumulatorsInRowMajor, - CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -> : DefaultMmaCore< - Shape, WarpShape, InstructionShape, - ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - OperatorClass, Stages, Operator, AccumulatorsInRowMajor, - CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -> {}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: simt class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Size of a threadblock-scoped access (a value of -1 indicates the default) - int kAccessSizeInBits_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCoreWithAccessSize>::type, ElementA_, - layout::ColumnMajor, ElementB_, layout::RowMajor, - ElementC_, LayoutC_, arch::OpClassSimt, kAccessSizeInBits_, 2, Operator_ - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - static int const PartitionsK = Shape::kK / WarpShape::kK; - - /// Default Operator - using Operator = Operator_; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - PartitionsK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - static int const kElementsPerAccessDefault = 1; - static_assert(kAccessSizeInBits_ == -1 || - sizeof_bits::value == sizeof_bits::value || - kAccessSizeInBits_ / sizeof_bits::value == kElementsPerAccessDefault, - "Non-default value for kAccessSizeInBits_ is only allowed if size(elementA) == sizeof(elementB)"); - static int const kElementsPerAccess = (kAccessSizeInBits_ != -1) ? kAccessSizeInBits_ / sizeof_bits::value : kElementsPerAccessDefault; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); - static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h deleted file mode 100644 index 72015956e905561b5f4be686dbeea2921b7ba3df..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h +++ /dev/null @@ -1,167 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming - expectations about data layout of the global memory fragments, data types, - and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp - instructions. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -#include "cutlass/gemm/threadblock/default_mma_core.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -#include "cutlass/gemm/threadblock/mma_with_reduction_multistage.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Template defininng default matrix multiply operators inferred from threadblock tile size, -/// global memory data layout, and target math instruction. -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape_, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Reduce operand A or B along K dimension - bool ReduceKForA_, - /// Number of stages - int Stages = 2, - /// Operation performed by MMA - typename Operator = typename platform::conditional< - (platform::is_same::value) && - (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global, - /// per-element transformation for elements of A - ComplexTransform TransformA = ComplexTransform::kNone, - /// per-element transformation for elements of B - ComplexTransform TransformB = ComplexTransform::kNone, - bool IsComplex = false// (is_complex::value || is_complex::value) -> -struct DefaultMmaWithReductionCore { - using Base = DefaultMmaCore; - using Shape = Shape_; - using IteratorThreadMapA = typename Base::IteratorThreadMapA; - using IteratorThreadMapB = typename Base::IteratorThreadMapB; - using SmemIteratorA = typename Base::SmemIteratorA; - using SmemIteratorB = typename Base::SmemIteratorB; - using SmemLayoutA = typename Base::SmemLayoutA; - using SmemLayoutB = typename Base::SmemLayoutB; - using WarpCount = typename Base::WarpCount; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp< - WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, - ElementC, LayoutC, Operator, ReduceKForA_, WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h deleted file mode 100644 index 7b3bbcf71ed389cc7f001bb943ce70c62a83dd5d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h +++ /dev/null @@ -1,712 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" -#include "cutlass/arch/wmma.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" - -#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// Operator: wmma tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - ///< Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_, - /// Number of stages - int Stages> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassWmmaTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // - // Shared memory layouts - // - // NOTE: shared memory layout for wmma is same as the operands' layout in the global memory - using SmemLayoutA = LayoutA; - using SmemLayoutB = LayoutB; - - // Pad shared memory to avoid bank conflicts - static int const kPaddingA = 128 / sizeof_bits::value; - static int const kPaddingB = 128 / sizeof_bits::value; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Wmma< - InstructionShape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape, - MatrixShape<0, kPaddingB>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: column-major -/// Operator: wmma tensorop class -/// -/// This uses the default warp-level operator given tile sizes -template < - ///< Shape of threadblock-scoped matrix multiply operator - ///< (concept:GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) [allowed - /// wmma instruction shapes, e.g., 16x16x16, 32x8x16, 8x32x16,...] - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_, - /// Number of stages - int Stages> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassWmmaTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads per threadblock - static int const kThreads = WarpCount::kCount * kWarpSize; - - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - // shared memory layout for wmma is same as the operands' layout in global memory - using SmemLayoutA = LayoutA; - using SmemLayoutB = LayoutB; - - // Pad shared memory to avoid bank conflicts - static int const kPaddingA = 128 / sizeof_bits::value; - static int const kPaddingB = 128 / sizeof_bits::value; - - // - // Iterators to write to shared memory - // - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB // SmemThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Wmma< - InstructionShape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, kPaddingA>, - MatrixShape, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: row-major -/// B: row-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by MMA - typename Operator_, - /// Number of stages - int Stages> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::RowMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassWmmaTensorOp; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." - ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - // - // Shared memory layouts - // - - // shared memory layout for wmma is same as the operands' layout in global memory - using SmemLayoutA = LayoutA; - using SmemLayoutB = LayoutB; - - // Pad shared memory to avoid bank conflicts - static int const kPaddingA = 128 / sizeof_bits::value; - static int const kPaddingB = 128 / sizeof_bits::value; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Wmma< - InstructionShape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, kPaddingA>, - MatrixShape<0, kPaddingB>, - WarpCount::kK - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: column-major -/// Operator: tensor op class -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by MMA - typename Operator_, - /// Number of stages - int Stages> -struct DefaultMmaCore { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::ColumnMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassWmmaTensorOp; - - /// Number of warps present - using WarpCount = - GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - /// Default Operator - using Operator = Operator_; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - // shared memory layout for wmma is same as the operands' layout in global memory - using SmemLayoutA = LayoutA; - using SmemLayoutB = LayoutB; - - // Pad shared memory to avoid bank conflicts - static int const kPaddingA = 128 / sizeof_bits::value; - static int const kPaddingB = 128 / sizeof_bits::value; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kAccessSizeInBits / sizeof_bits::value - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Wmma< - InstructionShape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator - >, - cutlass::MatrixShape<1, 1> - >; - - using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Policy - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape, - MatrixShape, - WarpCount::kK - >; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h deleted file mode 100644 index bce17dd19fab25040ab4be1c9e31421842637b79..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h +++ /dev/null @@ -1,178 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" -#include "cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h" -#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" -#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for Scale/Bias vectors - typename ElementScaleBias, - /// Layout type for Scale/Bias vectors - typename LayoutScaleBias, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Use zfill or predicate for SM80 out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone - > -struct DefaultMmaLayernormMainloopFusion { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - /// Define iterators over tiles from scale/bias vectors - using IteratorVarMean = - cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< - cutlass::MatrixShape<1, WarpShape::kN>, - ElementScaleBias, - LayoutScaleBias>; - - /// Define iterators over tiles from scale/bias vectors - using IteratorGammaBeta = - cutlass::transform::threadblock::PredicatedScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - using SmemIteratorGammaBeta = - cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< - cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, - LayoutScaleBias>; - - static int const kThreadCount = 32; - - // Warp-level iterators to load scale and bias vectors - using WarpIteratorGammaBeta = cutlass::gemm::warp::ScaleBiasTileIterator< - MatrixShape, ElementScaleBias, - LayoutScaleBias, MatrixShape, - typename MmaCore::MmaTensorOp::IteratorA::Base::Policy, kThreadCount, - MmaCore::WarpCount::kK>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaLayernormMainloopFusionMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, IteratorVarMean, IteratorGammaBeta, SmemIteratorGammaBeta, - CacheOpGammaBeta, - ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, WarpIteratorGammaBeta, Stages, SharedMemoryClear>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h deleted file mode 100644 index cab385aff88f9b4736da33de2819b19a2f9f0f9e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h +++ /dev/null @@ -1,136 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass/gemm/threadblock/mma_planar_complex_multistage.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Math operator tag (e.g. arch::OpMultiplyAdd) - typename Operator = arch::OpMultiplyAdd -> -struct DefaultMmaPlanarComplexMultistage { - - // Construct a planar complex variant from the real-valued variant - using RealMmaMultistage = typename DefaultMma< - ElementA_, - LayoutA_, - kAlignmentA, - ElementB_, - LayoutB_, - kAlignmentB, - ElementAccumulator_, - LayoutC_, - OperatorClass_, - ArchTag_, - ThreadblockShape_, - WarpShape_, - InstructionShape_, - Stages, - Operator - >::ThreadblockMma; - - using ThreadblockMma = MmaPlanarComplexMultistage< - ThreadblockShape_, - typename RealMmaMultistage::IteratorA, - typename RealMmaMultistage::SmemIteratorA, - cutlass::arch::CacheOperation::Global, - typename RealMmaMultistage::IteratorB, - typename RealMmaMultistage::SmemIteratorB, - cutlass::arch::CacheOperation::Global, - ElementAccumulator_, - LayoutC_, - typename RealMmaMultistage::Policy, - Stages, - TransformA, - TransformB - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h deleted file mode 100644 index 51327c1a382cfff194741d32cdcfcf32d2dca5b8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h +++ /dev/null @@ -1,130 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" - -#include "cutlass/gemm/warp/mma_planar_complex.h" -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass/gemm/threadblock/mma_planar_complex_pipelined.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Math operator tag (e.g. arch::OpMultiplyAdd) - typename Operator = arch::OpMultiplyAdd -> -struct DefaultMmaPlanarComplexPipelined { - - // Construct a planar complex variant from the real-valued variant - using RealMma = typename DefaultMma< - ElementA_, - LayoutA_, - kAlignmentA, - ElementB_, - LayoutB_, - kAlignmentB, - ElementAccumulator_, - LayoutC_, - OperatorClass_, - ArchTag_, - ThreadblockShape_, - WarpShape_, - InstructionShape_, - Stages, - Operator - >::ThreadblockMma; - - using ThreadblockMma = MmaPlanarComplexPipelined< - ThreadblockShape_, - typename RealMma::IteratorA, - typename RealMma::SmemIteratorA, - typename RealMma::IteratorB, - typename RealMma::SmemIteratorB, - ElementAccumulator_, - LayoutC_, - typename RealMma::Policy, - Stages, - TransformA, - TransformB - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h deleted file mode 100644 index c8c6cf7e248435bd5d931d23d837b4ea41b145bf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h +++ /dev/null @@ -1,160 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined softmax-GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" -#include "cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h" -#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" -#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for Scale/Bias vectors - typename ElementScaleBias, - /// Layout type for Scale/Bias vectors - typename LayoutScaleBias, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Whether problem has been transformed. This determines to which operand - /// the softmax is applied. - bool InternalTranspose, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Use zfill or predicate for SM80 out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone - > -struct DefaultMmaSoftmaxMainloopFusion { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - /// Define iterators over tiles from scale/bias vectors - using IteratorNormSum = - cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< - cutlass::MatrixShape<1, WarpShape::kN>, - ElementScaleBias, - LayoutScaleBias>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaSoftmaxMainloopFusionMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, IteratorNormSum, - ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, InternalTranspose, SharedMemoryClear>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h deleted file mode 100644 index ae1ac25346bec4339815cf5eb25f6d83e9e836a6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h +++ /dev/null @@ -1,141 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Operator class tag - typename OperatorClass, - /// - bool ReduceKForA_, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Use zfill or predicate for SM80 out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone - > -struct DefaultMmaWithReduction { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaWithReductionCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - ReduceKForA_, Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaWithReductionMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h deleted file mode 100644 index 62d0c49b338e09a21efb8148b9418ff200ee9dc7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h +++ /dev/null @@ -1,159 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator = arch::OpMultiplyAddComplex, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false> -struct DefaultMultistageMmaComplex; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageMmaComplex { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages>; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h deleted file mode 100644 index 8751495a58c5b403b67a43f7dedf16a39615bd3a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h +++ /dev/null @@ -1,119 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming - expectations about data layout of the global memory fragments, data types, - and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp - instructions. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/complex.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/gemm/warp/mma_simt_policy.h" -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -#include "cutlass/gemm/threadblock/default_mma_core.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/pitch_linear_thread_map.h" - -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Template defininng default matrix multiply operators inferred from -/// threadblock tile size, global memory data layout, and target math -/// instruction. -template < - /// Shape of threadblock-scoped matrix multiply operator - typename Shape, - /// Shape of warp-level matrix multiply operator - typename WarpShape, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape, - /// Element data type of A operand - typename ElementA, - /// Layout of operand A - typename LayoutA, - /// Element data type of B operand - typename ElementB, - /// Layout of operand B - typename LayoutB, - /// Data type of accumulator - typename ElementC, - /// Layout of accumulator - typename LayoutC, - /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) - typename OperatorClass, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator = arch::OpMultiplyAddComplex, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA = - cutlass::arch::CacheOperation::Global, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB = - cutlass::arch::CacheOperation::Global> -struct DefaultMultistageMmaComplexCore; - - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h deleted file mode 100644 index f9716f324fd9ee12ff1b7e0dd508d77c766514f8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h +++ /dev/null @@ -1,1808 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming - expectations about data layout of the global memory fragments, data types, - and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp - instructions. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/gemm/warp/mma_simt_policy.h" -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -#include "cutlass/gemm/threadblock/mma_multistage.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex double-precision -/// -/// A: column-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, InstructionShape_, - complex, layout::ColumnMajor, - complex, layout::RowMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = complex; - using LayoutA = layout::ColumnMajor; - using ElementB = complex; - using LayoutB = layout::RowMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped 128 - static int const kAccessSizeInBits = 128; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; - - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - - -/// Partial specialization for complex double-precision -/// -/// A: column-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, InstructionShape_, - complex, layout::ColumnMajor, - complex, layout::ColumnMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = complex; - using LayoutA = layout::ColumnMajor; - using ElementB = complex; - using LayoutB = layout::ColumnMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - using Operator = Operator_; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped 128 - static int const kAccessSizeInBits = 128; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex double-precision -/// -/// A: row-major -/// B: column-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, InstructionShape_, - complex, layout::RowMajor, - complex, layout::ColumnMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = complex; - using LayoutA = layout::RowMajor; - using ElementB = complex; - using LayoutB = layout::ColumnMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped 128 - static int const kAccessSizeInBits = 128; - - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - - -/// Partial specialization for complex double-precision -/// -/// A: row-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, InstructionShape_, - complex, layout::RowMajor, - complex, layout::RowMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = complex; - using LayoutA = layout::RowMajor; - using ElementB = complex; - using LayoutB = layout::RowMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped 128 - static int const kAccessSizeInBits = 128; - - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<8, 4>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex floating-point -/// -/// A: column-major -/// B: column-major -/// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<16, 8, 8>, - complex, layout::ColumnMajor, - complex, layout::ColumnMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<16, 8, 8>; - using ElementA = complex; - using LayoutA = layout::ColumnMajor; - using ElementB = complex; - using LayoutB = layout::ColumnMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped - static int const kAccessSizeInBits = 64; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; - - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - - -/// Partial specialization for complex floating-point -/// -/// A: column-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<16, 8, 8>, - complex, layout::ColumnMajor, - complex, layout::RowMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<16, 8, 8>; - using ElementA = complex; - using LayoutA = layout::ColumnMajor; - using ElementB = complex; - using LayoutB = layout::RowMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped - static int const kAccessSizeInBits = 64; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; - - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex floating-point -/// -/// A: row-major -/// B: column-major -/// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<16, 8, 8>, - complex, layout::RowMajor, - complex, layout::ColumnMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<16, 8, 8>; - using ElementA = complex; - using LayoutA = layout::RowMajor; - using ElementB = complex; - using LayoutB = layout::ColumnMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped - static int const kAccessSizeInBits = 64; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; - - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex floating-point -/// -/// A: row-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<16, 8, 8>, - complex, layout::RowMajor, - complex, layout::RowMajor, - complex, LayoutC_, - arch::OpClassTensorOp, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<16, 8, 8>; - using ElementA = complex; - using LayoutA = layout::RowMajor; - using ElementB = complex; - using LayoutB = layout::RowMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped - static int const kAccessSizeInBits = 64; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; - - using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 1, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< - layout::PitchLinearShape, kThreads, - layout::PitchLinearShape<16, 2>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 0, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< - WarpShape, InstructionShape, - ElementA, SmemLayoutA, - ElementB, SmemLayoutB, - ElementC, LayoutC, - kTransformA, kTransformB, - Operator>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex SIMT operation -/// -/// A: column-major -/// B: column-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - typename RealA, - typename RealB, - typename RealC, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<1, 1, 1>, - complex, layout::ColumnMajor, - complex, layout::ColumnMajor, - complex, LayoutC_, - arch::OpClassSimt, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = complex; - using LayoutA = layout::ColumnMajor; - using ElementB = complex; - using LayoutB = layout::ColumnMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of access - static int const kAccessSizeInBits = sizeof_bits::value; - - /// No vectorized accesses - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator B - using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - SmemThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - 1, /// 1 partition along K dimension - kTransformA, /// Transform for A - kTransformB /// Transform for B - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, Shape::kK / 32>, - WarpCount::kK>; -}; - -/// Partial specialization for complex SIMT operation -/// -/// A: column-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - typename RealA, - typename RealB, - typename RealC, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<1, 1, 1>, - complex, layout::ColumnMajor, - complex, layout::RowMajor, - complex, LayoutC_, - arch::OpClassSimt, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = complex; - using LayoutA = layout::ColumnMajor; - using ElementB = complex; - using LayoutB = layout::RowMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of access - static int const kAccessSizeInBits = sizeof_bits::value; - - /// No vectorized accesses - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - IteratorThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - 1, /// 1 partition along K dimension - kTransformA, /// Transform for A - kTransformB /// Transform for B - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape<0, 0>, - MatrixShape<0, 0>, // or Shape::kK / 32 - WarpCount::kK>; -}; - -/// Partial specialization for complex SIMT operation -/// -/// A: row-major -/// B: column-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - typename RealA, - typename RealB, - typename RealC, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<1, 1, 1>, - complex, layout::RowMajor, - complex, layout::ColumnMajor, - complex, LayoutC_, - arch::OpClassSimt, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = complex; - using LayoutA = layout::RowMajor; - using ElementB = complex; - using LayoutB = layout::ColumnMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of access - static int const kAccessSizeInBits = sizeof_bits::value; - - /// No vectorized accesses - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - SmemThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator B - using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - SmemThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - 1, /// 1 partition along K dimension - kTransformA, /// Transform for A - kTransformB /// Transform for B - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, - MatrixShape<0, Shape::kK / 32>, - WarpCount::kK>; -}; - -/// Partial specialization for complex SIMT operation -/// -/// A: row-major -/// B: row-major -/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - typename RealA, - typename RealB, - typename RealC, - /// Layout of accumulator - typename LayoutC_, - /// Number of stages - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_, - /// Cache operation of operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Cache operation of operand B - cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<1, 1, 1>, - complex, layout::RowMajor, - complex, layout::RowMajor, - complex, LayoutC_, - arch::OpClassSimt, - Stages, - TransformA, TransformB, - Operator_, - CacheOpA, CacheOpB> { - - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = GemmShape<1, 1, 1>; - using ElementA = complex; - using LayoutA = layout::RowMajor; - using ElementB = complex; - using LayoutB = layout::RowMajor; - using ElementC = complex; - using LayoutC = LayoutC_; - static int const kStages = Stages; - static ComplexTransform const kTransformA = TransformA; - static ComplexTransform const kTransformB = TransformB; - using Operator = Operator_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; - - /// Number of warps present - using WarpCount = GemmShape; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - static_assert(WarpCount::kCount > 1, - "This specialization requires at least two warps."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of access - static int const kAccessSizeInBits = sizeof_bits::value; - - /// No vectorized accesses - static int const kElementsPerAccess = 1; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::ColumnMajor; - - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Transpose the ThreadMap of iterator A - using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, - SmemThreadMapA>; - - /// Policy of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - kElementsPerAccess - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level op - static const int WarpNumThreadsM = 4; - static const int WarpNumThreadsN = 8; - static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), - "WarpShape must be divisible by ThreadTile shape."); - static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; - static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; - static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; - static const int numElementsA = 128 / sizeof_bits::value; - static const int numElementsB = 128 / sizeof_bits::value; - static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); - static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); - // these should have max of thread tile also - using LaneMmaShape = cutlass::gemm::GemmShape< - LaneM, - LaneN, - 1>; - using Policy = cutlass::gemm::warp::MmaSimtPolicy< - cutlass::MatrixShape, // WarpShape - cutlass::layout::RowMajorInterleaved, // LaneLayout - LaneMmaShape - >; - - using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - 1, /// 1 partition along K dimension - kTransformA, /// Transform for A - kTransformB /// Transform for B - >; /// Used for partial specialization - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaWarpSimt, - MatrixShape, - MatrixShape<0, 0>, // or Shape::kK / 32 - WarpCount::kK>; -}; - -//////////////////////////////////////////////////////////////////////////////// - - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h deleted file mode 100644 index 4045dd2e4173c072b359bfccf0e4c48f6c15146d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h +++ /dev/null @@ -1,556 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/arch/arch.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h" -#include "cutlass/gemm/threadblock/mma_blas3_multistage.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Diag Type for the triangular matrix - DiagType kDiagType, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator = arch::OpMultiplyAddComplex, - /// Blas3 computation mode - BlasMode BlasMode_ = BlasMode::kTriangular, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false> -struct DefaultMultistageTrmmComplex; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Diag Type for the triangular matrix - DiagType kDiagType, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageTrmmComplex { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, - kSideMode, kFillMode, kDiagType, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, - kSideMode, FillMode::kFull, DiagType::kInvalid, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output and right-side mode -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Diag Type for the triangular matrix - DiagType kDiagType, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageTrmmComplex { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, - SideMode::kRight, FillMode::kFull, DiagType::kInvalid, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, - SideMode::kRight, kFillMode, kDiagType, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output with unit diagonal -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageTrmmComplex { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, - kSideMode, kFillMode, DiagType::kUnit, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, - kSideMode, FillMode::kFull, DiagType::kInvalid, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output and right-side mode, unit diagonal -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageTrmmComplex { - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, - SideMode::kRight, FillMode::kFull, DiagType::kInvalid, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, - SideMode::kRight, kFillMode, DiagType::kUnit, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (for TRMM where diagonal imag part is ignored - used by HEMM) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageTrmmComplex { - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - // PredicatedTileAccessIteratorTriangularMatrix only tracks diagonal elements, - // when DiagType is kUnit - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, - kSideMode, kFillMode, DiagType::kUnit, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, - kSideMode, FillMode::kFull, DiagType::kInvalid, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill, - BlasMode::kHermitian>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output and right-side mode (for TRMM where diagonal imag part is ignored - used by HEMM) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Complex transformation on operand A - ComplexTransform TransformA, - /// Complex transformation on operand B - ComplexTransform TransformB, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator> -struct DefaultMultistageTrmmComplex { - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, - Stages, TransformA, TransformB, Operator>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, - SideMode::kRight, FillMode::kFull, DiagType::kInvalid, - AccessTypeA>; - - // Define iterators over tiles from the B operand - // PredicatedTileAccessIteratorTriangularMatrix only tracks diagonal elements, - // when DiagType is kUnit - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, - SideMode::kRight, kFillMode, DiagType::kUnit, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill, - BlasMode::kHermitian>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h deleted file mode 100644 index 3c8632c8f4a109df6d5b1f80903cb1dbdb34122e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h +++ /dev/null @@ -1,196 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/wmma.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -#endif //CUTLASS_ARCH_WMMA_ENABLED - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false - > -struct DefaultSparseMma; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultSparseMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - static int const kSparse = MmaCore::kSparse; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define iterators over tiles from the E operand - using ElementE = typename MmaCore::ElementE; - using LayoutE = typename MmaCore::GmemLayoutE; - using ThreadMapE = typename MmaCore::IteratorThreadMapE; - using AccessTypeE = - cutlass::Array::value>; - using IteratorE = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementE, LayoutE, 1, ThreadMapE, AccessTypeE>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::SparseMmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - IteratorE, typename MmaCore::SmemIteratorE, MmaCore::kCacheOpE, - typename MmaCore::MmaPolicy, Stages>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_trmm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_trmm.h deleted file mode 100644 index 066ecd6aa4cf6137f78b0ee502053f59d1d18354..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/default_trmm.h +++ /dev/null @@ -1,445 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -// -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/wmma.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h" -#include "cutlass/gemm/threadblock/mma_blas3_multistage.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) -#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -#endif //CUTLASS_ARCH_WMMA_ENABLED - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Diag Type for the triangular matrix - DiagType kDiagType, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false - > -struct DefaultTrmm; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Diag Type for the triangular matrix - DiagType kDiagType, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultTrmm { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, kSideMode, kFillMode, kDiagType, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, kSideMode, FillMode::kFull, DiagType::kInvalid, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output, right side mode (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Diag Type for the triangular matrix - DiagType kDiagType, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultTrmm { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, SideMode::kRight, FillMode::kFull, DiagType::kInvalid, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, SideMode::kRight, kFillMode, kDiagType, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output with unit diagonal (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Side Mode for the kernel - SideMode kSideMode, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultTrmm { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, kSideMode, kFillMode, DiagType::kUnit, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, kSideMode, FillMode::kFull, DiagType::kInvalid, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output, right side mode, unit diagonal (OperatorClass TensorOp) -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Fill Mode for the triangular matrix - FillMode kFillMode, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Number of stages used in the multistage mainloop - int Stages, - /// Operation performed by GEMM - typename Operator - > -struct DefaultTrmm { - - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, - Stages, Operator, false, CacheOpA, CacheOpB>; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, SideMode::kRight, FillMode::kFull, DiagType::kInvalid, AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, SideMode::kRight, kFillMode, DiagType::kUnit, AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h deleted file mode 100644 index 83723619e8494c138bd0d17cb91b09fbfff27b39..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h +++ /dev/null @@ -1,648 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a multistage threadblock-scoped Blocked-Ell MMA. -*/ - -#pragma once - - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class EllMmaMultistage : - public MmaBase { -public: - ///< Base class - using Base = MmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - EllMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - template - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, EllIterator &ell_iter, - int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - bool is_valid = iterator_A.valid(); - - if (!is_A_sparse){ - if (is_offset_constant){ - auto ell_offset = ell_iter.get_offset_fast(); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(typename IteratorA::Element) / kSrcBytes; - } else { - int k_offset = iterator_A.get_k(); - auto ell_offset = ell_iter.get_offset(k_offset); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += (ell_offset * sizeof(typename IteratorA::Element)) / kSrcBytes; - } - } - - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, is_valid); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - bool is_valid = iterator_B.valid(); - - if (is_A_sparse){ - if (is_offset_constant){ - auto ell_offset = ell_iter.get_offset_fast(); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(typename IteratorB::Element) / kSrcBytes; - } else { - int k_offset = iterator_B.get_k(); - auto ell_offset = ell_iter.get_offset(k_offset); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ( ell_offset * sizeof(typename IteratorB::Element)) / kSrcBytes; - } - } - - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, is_valid); - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - - /// Perform a threadblock-scoped matrix multiply-accumulate - template - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const &src_accum, - EllIterator &ell_iterator - ) { - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - auto gmem_ptr = iterator_A.get(); - bool is_valid = iterator_A.valid(); - - if (!is_A_sparse){ - if (is_offset_constant){ - auto ell_offset = ell_iterator.get_offset_fast(); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(typename IteratorA::Element) / kSrcBytes; - } else { - int k_offset = iterator_A.get_k(); - auto ell_offset = ell_iterator.get_offset(k_offset); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += (ell_offset * sizeof(typename IteratorA::Element)) / kSrcBytes; - } - } - - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, is_valid); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - auto gmem_ptr = iterator_B.get(); - bool is_valid = iterator_B.valid(); - - if (is_A_sparse){ - if (is_offset_constant){ - auto ell_offset = ell_iterator.get_offset_fast(); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(typename IteratorB::Element) / kSrcBytes; - } else { - int k_offset = iterator_B.get_k(); - auto ell_offset = ell_iterator.get_offset(k_offset); - is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ( ell_offset * sizeof(typename IteratorB::Element)) / kSrcBytes; - } - } - - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, is_valid); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - ++ell_iterator; - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - if (is_A_sparse){ - iterator_A.ell_add_mask(ell_iterator.get_blocksize()); - } - else { - iterator_B.ell_add_mask(ell_iterator.get_blocksize()); - } - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // tf32x3 kernels use staging accumulation. warp_mma uses a temporary - // accumulator and this temporary accumulator is added to the final - // accumulator once in every mainloop iteration. - plus plus_accum; - - FragmentC tmp_accum; - - if (platform::is_same::value - || platform::is_same::value) { - - tmp_accum.clear(); - } - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - if (platform::is_same::value - || platform::is_same::value) { - - warp_mma( - tmp_accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - tmp_accum - ); - - if (warp_mma_k == 0) { - accum = plus_accum(accum, tmp_accum); - tmp_accum.clear(); - } - } else { - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - } - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance( - iterator_A, iterator_B, ell_iterator, group_start_iteration_A, - group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance( - iterator_A, iterator_B, ell_iterator, group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - ++ell_iterator; - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - } - - } - - if (platform::is_same::value - || platform::is_same::value) { - accum = plus_accum(accum, tmp_accum); - } - - - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h deleted file mode 100644 index adcff38d23b8bd527284333253b5d54659808c8f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h +++ /dev/null @@ -1,376 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped Blocked-Ell MMA. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Transformation applied to A operand - typename TransformA_ = NumericArrayConverter< - typename SmemIteratorA_::Element, - typename IteratorA_::Element, - IteratorA_::Fragment::kElements>, - /// - /// Transformation applied to B operand - typename TransformB_ = NumericArrayConverter< - typename SmemIteratorB_::Element, - typename IteratorB_::Element, - IteratorB_::Fragment::kElements>, - /// Used for partial specialization - typename Enable = bool -> -class EllMmaPipelined : public MmaBase { -public: - - ///< Base class - using Base = MmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - using TransformA = TransformA_; - using TransformB = TransformB_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for EllMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages==2), "EllMmaPipelined requires kStages set to value 2"); - -private: - - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - -protected: - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - EllMmaPipelined( - typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - template - CUTLASS_DEVICE - void operator()( - int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum, ///< source accumulator tile - EllIterator &ell_iterator, - TransformA transform_A = TransformA(), ///< transformation applied to A fragment - TransformB transform_B = TransformB()) { ///< transformation applied to B fragment - - // - // Prologue - // - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // load sparse matrix - if (is_A_sparse){ - iterator_A.load(tb_frag_A); - } else { - iterator_B.load(tb_frag_B); - } - - // load dense matrix - if (is_offset_constant){ - if (is_A_sparse){ - iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); - } else { - iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); - } - } else { - if (is_A_sparse){ - iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); - } else { - iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); - } - } - - ++iterator_A; - ++iterator_B; - ++ell_iterator; - - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - if (is_A_sparse){ - iterator_A.ell_add_mask(ell_iterator.get_blocksize()); - } - else { - iterator_B.ell_add_mask(ell_iterator.get_blocksize()); - } - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tightest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, - 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k == 0) { - // load sparse matrix - if (is_A_sparse){ - iterator_A.load(tb_frag_A); - } else { - iterator_B.load(tb_frag_B); - } - - // load dense matrix - if (is_offset_constant){ - if (is_A_sparse){ - iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); - } else { - iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); - } - } else { - if (is_A_sparse){ - iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); - } else { - iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); - } - } - - ++iterator_A; - ++iterator_B; - ++ell_iterator; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - warp_mma(accum, warp_frag_A[warp_mma_k % 2], - warp_frag_B[warp_mma_k % 2], accum); - } - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/gemv.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/gemv.h deleted file mode 100644 index ab747374d8f7b15b65371975379e17b8aee1707f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/gemv.h +++ /dev/null @@ -1,147 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a threadblock-scoped GEMV kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix-vector product using SIMT math instructions. -template < - class Core_ //< GemvCore -> -class Gemv { -public: - using Shape = typename Core_::Shape; - - /// The MMA operator that computes GEMV - using Operator = typename Core_::Operator; - - /// Iterates over A in global memory - using IteratorA = typename Core_::IteratorA; - - /// Iterates over B in global memory - using IteratorB = typename Core_::IteratorB; - - /// Fragment of operand C loaded from global memory - using IteratorC = typename Core_::IteratorC; - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand accumulator loaded/stored to global memory - using FragmentC = typename Operator::FragmentC; - - /// Shape of the per-thread GEMV operation - using ThreadShape = typename Core_::ThreadShape; - -public: - CUTLASS_DEVICE - Gemv() { } - - CUTLASS_DEVICE - void operator()( - GemmCoord const &problem_size, ///< problem size of batched GEMV - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum) { ///< source accumulator tile - - // - // Prologue - // - - FragmentA frag_A; - FragmentB frag_B; - frag_A.clear(); - frag_B.clear(); - - iterator_A.load(frag_A); - iterator_B.load(frag_B); - ++iterator_A; - ++iterator_B; - - // - // Mainloop - // - Operator thread_mma; - int gemm_k = problem_size.k(); - - if (gemm_k < Shape::kK) - { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } - - // iterate over K to accumulate result - CUTLASS_GEMM_LOOP - for (; gemm_k > 0; gemm_k -= Shape::kK) { - thread_mma(accum, frag_A, frag_B, accum); - - iterator_A.load(frag_A); - iterator_B.load(frag_B); - ++iterator_A; - ++iterator_B; - - if (gemm_k < Shape::kK) - { - iterator_A.clear_mask(); - iterator_B.clear_mask(); - } - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/index_remat.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/index_remat.h deleted file mode 100644 index 89e4b1af9c21d115632cd98f20bbc113de3b236b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/index_remat.h +++ /dev/null @@ -1,107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Helpers for rematerializing indices/dimensions in the thread hierarchy from special registers -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeThreadIdxX() { - return threadIdx.x; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeThreadIdxY() { - return threadIdx.y; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeThreadIdxZ() { - return threadIdx.z; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockIdxX() { - return blockIdx.x; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockIdxY() { - return blockIdx.y; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockIdxZ() { - return blockIdx.z; -} - -/// Helper to rematerialize block Dim. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockDimX() { - return blockDim.x; -} - -/// Helper to rematerialize block Dim. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockDimY() { - return blockDim.y; -} - -/// Helper to rematerialize block Dim. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockDimZ() { - return blockDim.z; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_base.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_base.h deleted file mode 100644 index 2eaa40b707aef310fedc2cb226da1d26d8f0fdb2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_base.h +++ /dev/null @@ -1,236 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/tensor_ref.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Policy object describing MmaTensorOp -template < - /// Warp-level GEMM operator (concept: gemm::warp::Mma) - typename Operator_, - /// Padding used for A operand in shared memory (concept: MatrixShape) - typename SmemPaddingA_, - /// Padding used for B operand in shared memory (concept: MatrixShape) - typename SmemPaddingB_, - /// Number of partitions of K dimension of GEMM - int PartitionsK = 1> -struct MmaPolicy { - /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) - using Operator = Operator_; - - /// Padding used for A operand in shared memory - using SmemPaddingA = SmemPaddingA_; - - /// Padding used for B operand in shared memory - using SmemPaddingB = SmemPaddingB_; - - /// Number of partitions of K dimension - static int const kPartitionsK = PartitionsK; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class MmaBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - static_assert(kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - static_assert((kWarpGemmIterations % 2) == 0, - "Inner loop iteration must be an even number."); - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h deleted file mode 100644 index e94c1de2cb6c17befd8ebd856a503b619bd73be7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h +++ /dev/null @@ -1,707 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. - Used by BLAS3 kernels that need to treat diagonal elements of a input iterator as a special case. - -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kZfill, - /// Blas3 computation mode - BlasMode BlasMode_ = BlasMode::kTriangular, - /// Used for partial specialization - typename Enable = bool> -class MmaBlas3Multistage : - public MmaBase { -public: - ///< Base class - using Base = MmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - ///< Blas Mode - static BlasMode const kBlasMode = BlasMode_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaBlas3Multistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - bool isvalid = iterator_A.valid(); - - if (isvalid && iterator_A.getOnDiag()) { - // Elements that are on diagonal - if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { - /* Copy real part from gmem, write zero for imag part in smem */ - /* The following logic to determine kSizeRealBytes is so that compiler doesn't complain when - * compiling for not complex datatype and using half the size for cp_async_zfill */ - int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, true); - cutlass::arch::cp_async_diag( - reinterpret_cast (dst_ptr + v) + kSizeRealBytes); - } else { - /* Write one (1) directly to smem*/ - cutlass::arch::cp_async_diag(dst_ptr + v); - } - } else { - // Elements that are not of diagonal - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, isvalid); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - bool isvalid = iterator_B.valid(); - - if (isvalid && iterator_B.getOnDiag()) { - // Elements that are on diagonal - if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { - /* Copy real part from gmem, write zero for imag part in smem */ - int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, true); - cutlass::arch::cp_async_diag( - reinterpret_cast (dst_ptr + v) + kSizeRealBytes); - } else { - /* Write one (1) directly to smem*/ - cutlass::arch::cp_async_diag(dst_ptr + v); - } - } else { - // Elements that are not of diagonal - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, isvalid); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - auto gmem_ptr = iterator_A.get(); - bool isvalid = iterator_A.valid(); - - if (isvalid && iterator_A.getOnDiag()) { - // Elements that are on diagonal - if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { - /* Copy real part from gmem, write zero for imag part in smem */ - int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, true); - cutlass::arch::cp_async_diag( - reinterpret_cast (dst_ptr + v) + kSizeRealBytes); - } else { - /* Write one (1) directly to smem*/ - cutlass::arch::cp_async_diag(dst_ptr + v); - } - } else { - // Elements that are not of diagonal - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, isvalid); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - auto gmem_ptr = iterator_B.get(); - bool isvalid = iterator_B.valid(); - - if (isvalid && iterator_B.getOnDiag()) { - // Elements that are on diagonal - if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { - /* Copy real part from gmem, write zero for imag part in smem */ - int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, true); - cutlass::arch::cp_async_diag( - reinterpret_cast (dst_ptr + v) + kSizeRealBytes); - } else { - /* Write one (1) directly to smem*/ - cutlass::arch::cp_async_diag(dst_ptr + v); - } - } else { - // Elements that are not of diagonal - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, isvalid); - } - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // tf32x3 kernels use staging accumulation. warp_mma uses a temporary - // accumulator and this temporary accumulator is added to the final - // accumulator once in every mainloop iteration. - plus plus_accum; - - FragmentC tmp_accum; - - if (platform::is_same::value - || platform::is_same::value) { - - tmp_accum.clear(); - } - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - if (platform::is_same::value - || platform::is_same::value) { - - warp_mma( - tmp_accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - tmp_accum - ); - - if (warp_mma_k == 0) { - accum = plus_accum(accum, tmp_accum); - tmp_accum.clear(); - } - } else { - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - } - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - } - - } - - if (platform::is_same::value - || platform::is_same::value) { - accum = plus_accum(accum, tmp_accum); - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h deleted file mode 100644 index 1f533dde28e4353fc9516344c529e85349db8d09..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h +++ /dev/null @@ -1,863 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. - - It loads two loop invariant vectors, mean and var, in the prologue and - stores them in the register file. In the mainloop, it loads two loop - variant vectors, gamma and beta, by using cp.async. We will call - elementwise operation to apply var, mean, gamma, beta between ldmatrix and - warp mma. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -#include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/gemm/warp/layernorm_scale_bias_transform.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Element type of scale and bias vectors - typename ElementScaleBias_, - /// Layout of scale and bias vectors - typename LayoutScaleBias_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// WarpIterator to load Scale or Bias vector from the shared memory - typename WarpIteratorGammaBeta_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class MmaMainloopFusionBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Element type of scale and bias vectors - using ElementScaleBias = ElementScaleBias_; - - /// Layout of scale and bias vectors - using LayoutScaleBias = LayoutScaleBias_; - - ///< Policy describing tuning details - using Policy = Policy_; - - ///< WarpIterator to load Scale or Bias vector from the shared memory - using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = cutlass::gemm::GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the scale and bias vectors - using TensorRefGammaBeta = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Shape of the A scale and bias vectors in shared memory - using ShapeGammaBeta = - MatrixShape<1 + Policy::SmemPaddingA::kRow, - 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer for A operand Scale and Bias - AlignedBuffer operand_A_gamma_beta; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a layout object for the A scale and bias vectors - CUTLASS_DEVICE - static LayoutScaleBias LayoutScaleBias() { - return LayoutScaleBias::packed( - {ShapeGammaBeta::kRow, ShapeGammaBeta::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - - /// Returns a TensorRef to the A operand Scale vector - CUTLASS_HOST_DEVICE - TensorRefGammaBeta operand_A_gamma_beta_ref() { - return TensorRefGammaBeta{operand_A_gamma_beta.data(), LayoutScaleBias()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of A operand scale and bias vector - /// from shared memory - WarpIteratorGammaBeta warp_tile_iterator_A_gamma_beta_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaMainloopFusionBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_A_gamma_beta_( - shared_storage.operand_A_gamma_beta_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -}; - - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterates over vectors of var and mean vector in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorVarMean_, - /// Iterates over vectors of scale and bias vector in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorGammaBeta_, - /// Iterates over vectors of scale and bias vector in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorGammaBeta_, - /// Cache operation for scale/bias operand - cutlass::arch::CacheOperation::Kind CacheOpGammaBeta, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// WarpIterator to load Scale or Bias vector from the shared memory - typename WarpIteratorGammaBeta_, - /// Number of stages, - int Stages, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class MmaLayernormMainloopFusionMultistage : - public MmaMainloopFusionBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of the var and mean vectors in global memory - using IteratorVarMean = IteratorVarMean_; - ///< Iterates over tiles of the scale and bias vectors in global memory - using IteratorGammaBeta = IteratorGammaBeta_; - ///< WarpIterator to load Scale or Bias vector from the shared memory - using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; - ///< Policy describing tuning details - using Policy = Policy_; - - ///< Base class - using Base = MmaMainloopFusionBase; - - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorGammaBeta = SmemIteratorGammaBeta_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - static cutlass::arch::CacheOperation::Kind const kCacheOpGammaBeta = - CacheOpGammaBeta; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - using WarpLoadedFragmentVarMean = typename IteratorVarMean::Fragment; - using WarpLoadedFragmentGammaBeta = - typename WarpIteratorGammaBeta::Fragment; - - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory - SmemIteratorGammaBeta smem_iterator_A_gamma_beta_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - int warp_idx_m_; - - int warp_idx_n_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaLayernormMainloopFusionMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_A_gamma_beta_(shared_storage.operand_A_gamma_beta_ref(), - thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; - warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( - {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, - IteratorGammaBeta &iterator_A_gamma_beta, - IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - // Async Copy for operand A scale and bias vector. Scale and bias vectors - // are small. One iteration is enough. - if (group_start_A == 0) { - typename IteratorGammaBeta::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_gamma_beta_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorGammaBeta::kElementsPerAccess / 8; - - cutlass::arch::cp_async( - dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over B operand in global memory - IteratorVarMean iterator_var_mean, - ///< iterator over scale and bias vectors in global memory - IteratorGammaBeta iterator_A_gamma_beta, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - // Issue several complete stages - - WarpLoadedFragmentVarMean warp_loaded_frag_var_mean; - iterator_var_mean.add_tile_offset({0, warp_idx_m_}); - iterator_var_mean.load(warp_loaded_frag_var_mean); - - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - // Async Copy for operand A scale and bias vectors. Scale and bias - // vectors are small. One iteration is enough. - { - typename IteratorGammaBeta::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_gamma_beta_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorGammaBeta::kElementsPerAccess / 8; - - cutlass::arch::cp_async( - dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_A_gamma_beta.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpLoadedFragmentGammaBeta warp_loaded_frag_A_gamma_beta[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - cutlass::gemm::warp::LayernormScaleBiasTransform - elementwise_transform; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_A_gamma_beta_.load( - warp_loaded_frag_A_gamma_beta[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_A_gamma_beta_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - elementwise_transform(warp_transformed_frag_A[0], - warp_loaded_frag_var_mean, - warp_loaded_frag_A_gamma_beta[0]); - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index( - (warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_A_gamma_beta_.load( - warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_A_gamma_beta_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) { - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_loaded_frag_var_mean, - warp_loaded_frag_A_gamma_beta[warp_mma_k % 2]); - } - - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, - group_start_iteration_A, - group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_A_gamma_beta.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_A_gamma_beta_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - elementwise_transform( - warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_var_mean, - warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); - } - } - - } - - // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h deleted file mode 100644 index ed278806f5f051c2bef3ac5dc9cad3becf24bcea..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h +++ /dev/null @@ -1,741 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class MmaMultistage : - public MmaBase { -public: - ///< Base class - using Base = MmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical - // accuracy, where each mainloop iteration first accumulates into a temporary - // set of freshly-cleared accumulators, which are subsequently added to the - // final accumulator set. - static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; - }; - - private: - - - // Structure encapsulating pipeline state live from one iteration to the next - struct PipeState { - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - /// Temporary accumulator to facilitate staged-accumulation - FragmentC tmp_accum_; - - /// Pair of A fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentA warp_loaded_frag_A_[2]; - WarpTransformedFragmentA warp_transformed_frag_A_[2]; - - /// Pair of B fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentB warp_loaded_frag_B_[2]; - WarpTransformedFragmentB warp_transformed_frag_B_[2]; - }; - - - private: - - // - // Data members - // - - /// Warp-level MMA operator - Operator warp_mma_; - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Shared memory write stage index - int smem_write_stage_idx_; - - /// Shared memory read stage index - int smem_read_stage_idx_; - - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - /// Advance shared memory read-iterators to the next stage - CUTLASS_DEVICE - void advance_smem_read_stage() - { - ++smem_read_stage_idx_; - - if (smem_read_stage_idx_ == Base::kStages) { - // Wrap back around to the 'start' of the circular buffer in shared memory - this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); - smem_read_stage_idx_ = 0; - } - } - - /// Advance global memory read-iterators and shared memory write-iterators to the stage - CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B) - { - // Advance global iterators - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - // Advance shared iterators - smem_iterator_A_.add_tile_offset({0, 1}); - smem_iterator_B_.add_tile_offset({1, 0}); - - // Increment shared memory write stage index - ++smem_write_stage_idx_; - - if (smem_write_stage_idx_ == Base::kStages) { - // Wrap back around to the 'start' of the circular buffer in shared memory - smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx_ = 0; - } - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching - /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - CUTLASS_DEVICE - void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining - { - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - - // Disable global fetching if done with global fetch iterations - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Optionally clear the remaining stages of SMEM. This is a functional requirement for - // some kernels so that all accumulator elements outside the GEMM footprint are zero. - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - typename IteratorA::AccessType zero_A; - - zero_A.clear(); - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - } - - - /// Wait until we have at least one completed global fetch stage - CUTLASS_DEVICE - void gmem_wait() - { - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - } - - - /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - CUTLASS_DEVICE - void mac_loop_iter( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining - { - // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load the next warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_B_; - - // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); - } - - // Execute the current warp-tile of MMA operations - if (Detail::kStagedAccumulation) { - warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); - - if (warp_mma_k == 0) { - plus plus_accum; - accum = plus_accum(accum, pipe_state.tmp_accum_); - pipe_state.tmp_accum_.clear(); - } - } else { - warp_mma_( - accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); - } - - // Except for the last warp-tile, all warp-tiles issue their share of - // global->shared fragment copies - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance( - iterator_A, - iterator_B, - group_start_iteration_A, - group_start_iteration_B); - } - - // The second-to-last warp-tile also: - // - performs the last warp-tile's share of global->shared fragment copies - // - moves to the next global fetch stage - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - - // Performs the last warp-tile's share of global->shared fragment copies - int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance( - iterator_A, - iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one completed global fetch stage - gmem_wait(); - - // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B); - advance_smem_read_stage(); - - // Disable global fetching when done with global fetch iterations - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // The last warp-tile also converts the shared memory fragments used by - // the first warp-tile of the next iteration, if necessary (so we can - // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - } - - } - } - - - /// Perform the specified number of threadblock mainloop iterations of matrix - /// multiply-accumulate. Assumes prologue has been initiated. - CUTLASS_DEVICE - void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory - { - PipeState pipe_state; - - // Disable global fetching if done with global fetch iterations - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - // Load first warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); - ++this->warp_tile_iterator_A_; - - // Load first warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); - ++this->warp_tile_iterator_B_; - - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); - - if (Detail::kStagedAccumulation) { - pipe_state.tmp_accum_.clear(); - } - - // Mainloop - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - mac_loop_iter( - pipe_state, - accum, - iterator_A, - iterator_B, - gemm_k_iterations); - } - - if (Detail::kStagedAccumulation) { - plus plus_accum; - accum = plus_accum(accum, pipe_state.tmp_accum_); - } - - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } - - - /// Prepares the class for another prologue. - CUTLASS_DEVICE - void wind_down() - { - // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) - - // First, increment remaining warp tiles to get to the next full stage. (Ideally we would - // just decrement one tile, but not all iterators implement --() decrement.) - #pragma unroll - for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); - this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - } - smem_read_stage_idx_++; - - // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) - static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; - if (smem_read_stage_idx_ > 1) - { - this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); - } - smem_read_stage_idx_ = smem_write_stage_idx_; - } - - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, gemm_k_iterations); - - // Wait until we have at least one completed global fetch stage - gmem_wait(); - - // Initialize destination accumulators with source accumulators - accum = src_accum; - - // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h deleted file mode 100644 index 87ccc0a6138ff899aa20db15c3ceca890bd29976..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h +++ /dev/null @@ -1,439 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Transformation applied to A operand - typename TransformA_ = NumericArrayConverter< - typename SmemIteratorA_::Element, - typename IteratorA_::Element, - IteratorA_::Fragment::kElements>, - /// - /// Transformation applied to B operand - typename TransformB_ = NumericArrayConverter< - typename SmemIteratorB_::Element, - typename IteratorB_::Element, - IteratorB_::Fragment::kElements>, - /// Used for partial specialization - typename Enable = bool -> -class MmaPipelined : public MmaBase { -public: - - ///< Base class - using Base = MmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - using TransformA = TransformA_; - using TransformB = TransformB_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); - -protected: - - // - // Data members - // - - /// Warp-level MMA operator - Operator warp_mma; - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - ///< transformation applied to A fragment - TransformA transform_A_; - - ///< transformation applied to B fragment - TransformB transform_B_; - - /// Shared memory write stage index - int smem_write_stage_idx; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaPipelined( - typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx, ///< ID of each thread within a warp - TransformA transform_A = TransformA(), ///< transformation applied to A fragment - TransformB transform_B = TransformB() ///< transformation applied to B fragment - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - transform_A_(transform_A), - transform_B_(transform_B), - smem_write_stage_idx(0) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - - /// Advance shared memory write-iterators to the next stage - CUTLASS_DEVICE - void advance_smem_write_stage() - { - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - - smem_write_stage_idx ^= 1; - } - - /// Advance shared memory read- and write-iterators to the next stage - CUTLASS_DEVICE - void advance_smem_stages() - { - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - // wrap write stage - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else - { - // wrap read stage - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); - } - - smem_write_stage_idx ^= 1; - } - - - /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching - /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - CUTLASS_DEVICE - void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining - { - // The last kblock is loaded in the prolog - - // Load A fragment from global A - FragmentA tb_frag_A; - tb_frag_A.clear(); - iterator_A.load(tb_frag_A); - ++iterator_A; - - // Load B fragment from global B - FragmentB tb_frag_B; - tb_frag_B.clear(); - iterator_B.load(tb_frag_B); - ++iterator_B; - - // Store A and B fragments to shared - this->smem_iterator_A_.store(transform_A_(tb_frag_A)); - this->smem_iterator_B_.store(transform_B_(tb_frag_B)); - - // Advance write stage - advance_smem_write_stage(); - } - - /// Wait until we have at least one completed global fetch stage - CUTLASS_DEVICE - void gmem_wait() - { - __syncthreads(); - } - - - /// Perform the specified number of threadblock mainloop iterations of matrix - /// multiply-accumulate. Assumes prologue has been initiated. - CUTLASS_DEVICE - void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory - { - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - // Load A fragment from shared A - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - ++this->warp_tile_iterator_A_; - - // Load B fragment from shared B - this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - ++this->warp_tile_iterator_B_; - - // Pair of fragments used to overlap global memory loads and math instructions; - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transform_A_(tb_frag_A)); - - this->smem_iterator_B_.store(transform_B_(tb_frag_B)); - - // Wait until we have at least one completed global fetch stage - gmem_wait(); - - // Advance smem read and write stages - advance_smem_stages(); - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k == 0) { - - // Load fragment from global A - tb_frag_A.clear(); - iterator_A.load(tb_frag_A); - ++iterator_A; - - // Load fragment from global B - tb_frag_B.clear(); - iterator_B.load(tb_frag_B); - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - warp_mma( - accum, - warp_frag_A[warp_mma_k % 2], - warp_frag_B[warp_mma_k % 2], - accum); - } - } - - } - - - /// Prepares the class for another prologue. - CUTLASS_DEVICE - void wind_down() - { - // First, increment remaining warp tiles to catch it up with the write stage. - #pragma unroll - for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); - this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - } - - // If we bumped the read iterators to the end of the circular buffer, wrap them around to - // align them with the write iterators - if (smem_write_stage_idx == 0) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum) ///< source accumulator tile - { - // Prologue - prologue(iterator_A, iterator_B, gemm_k_iterations); - - // Wait until we have at least one completed global fetch stage - gmem_wait(); - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h deleted file mode 100644 index b0ba5094c5d2ba0ae4ec23b0068161a54ad7ba99..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h +++ /dev/null @@ -1,208 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class MmaPlanarComplexBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Stride to the imaginary part of the A operand - static int const kImaginaryStrideA = ShapeA::kCount; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - /// Stride to the imaginary part of the A operand - static int const kImaginaryStrideB = ShapeB::kCount; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaPlanarComplexBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h deleted file mode 100644 index 6bb9e6604f1b0cec1172e637831f2b4eb60053b0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h +++ /dev/null @@ -1,646 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/array_planar_complex.h" -#include "cutlass/functional.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Transformation applied to A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Transformation applied to B - ComplexTransform TransformB = ComplexTransform::kNone -> -class MmaPlanarComplexMultistage : - public MmaPlanarComplexBase { -public: - ///< Base class - using Base = MmaPlanarComplexBase; - - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - - ///< Data type of accumulator matrix - using ElementC = ElementC_; - - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - - ///< Policy describing tuning details - using Policy = Policy_; - - ///< Architecture tag - using ArchTag = arch::Sm80; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - /// Transformation applied to A - static ComplexTransform const kTransformA = TransformA; - - /// Transformation applied to B - static ComplexTransform const kTransformB = TransformB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = ArrayPlanarComplex< - typename Policy::Operator::FragmentC::Element, - Policy::Operator::FragmentC::kElements - >; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Internal structure exposed for introspection. - struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const TBLoadIterationsA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const TBLoadIterationsB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - static int const kAccessesPerGroupA = - (TBLoadIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - static int const kAccessesPerGroupB = - (TBLoadIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaPlanarComplexMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - -private: - - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA &iterator_A_real, - IteratorA &iterator_A_imag, - - IteratorB &iterator_B_real, - IteratorB &iterator_B_imag, - - int group_start_A = 0, - int group_start_B = 0) { - - iterator_A_real.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - iterator_A_imag.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Load for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - - auto gmem_ptr_real = iterator_A_real.get(); - auto gmem_ptr_imag = iterator_A_imag.get(); - - bool pred_guard = iterator_A_real.valid(); - cutlass::arch::cp_async( - dst_ptr + v, - gmem_ptr_real, - pred_guard); - cutlass::arch::cp_async( - dst_ptr + v + (Base::SharedStorage::kImaginaryStrideA / IteratorA::ThreadMap::kElementsPerAccess), - reinterpret_cast(gmem_ptr_imag), - pred_guard); - - ++iterator_A_real; - ++iterator_A_imag; - } - - ++this->smem_iterator_A_; - } - - iterator_B_real.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - iterator_B_imag.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Load for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr_real = iterator_B_real.get(); - auto gmem_ptr_imag = iterator_B_imag.get(); - - bool pred_guard = iterator_B_real.valid(); - cutlass::arch::cp_async( - dst_ptr + v, - gmem_ptr_real, - pred_guard); - cutlass::arch::cp_async( - dst_ptr + v + (Base::SharedStorage::kImaginaryStrideB / IteratorB::ThreadMap::kElementsPerAccess), - reinterpret_cast(gmem_ptr_imag), - pred_guard); - - ++iterator_B_real; - ++iterator_B_imag; - } - ++this->smem_iterator_B_; - } - } - - CUTLASS_DEVICE - void warp_mma_planar_complex( - Operator & warp_mma, - FragmentC &accum, - WarpFragmentA const & real_A, - WarpFragmentA const & imag_A, - WarpFragmentB const & real_B, - WarpFragmentB const & imag_B) { - - cutlass::negate> neg_op_B; - - WarpFragmentB neg_real_B = neg_op_B(real_B); - WarpFragmentB neg_imag_B = neg_op_B(imag_B); - - warp_mma(accum.real, real_A, real_B, accum.real); - - if (kTransformB == ComplexTransform::kNone) { - warp_mma(accum.imag, real_A, imag_B, accum.imag); - } - else { - warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); - } - - if (kTransformA == ComplexTransform::kNone) { - warp_mma(accum.imag, imag_A, real_B, accum.imag); - } - else { - warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); - } - - if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { - warp_mma(accum.real, imag_A, imag_B, accum.real); - } - else { - warp_mma(accum.real, imag_A, neg_imag_B, accum.real); - } - } - -public: - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A_real, - ///< iterator over A operand in global memory - IteratorA iterator_A_imag, - ///< iterator over B operand in global memory - IteratorB iterator_B_real, - ///< iterator over B operand in global memory - IteratorB iterator_B_imag, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A_real.clear_mask(gemm_k_iterations == 0); - iterator_A_imag.clear_mask(gemm_k_iterations == 0); - iterator_B_real.clear_mask(gemm_k_iterations == 0); - iterator_B_imag.clear_mask(gemm_k_iterations == 0); - - iterator_A_real.set_iteration_index(0); - iterator_A_imag.set_iteration_index(0); - - this->smem_iterator_A_.set_iteration_index(0); - - // Load for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLoadIterationsA; ++j) { - - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - bool pred_guard = iterator_A_real.valid(); - - auto src_ptr_real = iterator_A_real.get(); - auto src_ptr_imag = iterator_A_imag.get(); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, src_ptr_real, pred_guard); - - cutlass::arch::cp_async_zfill( - dst_ptr + v + - Base::SharedStorage::kImaginaryStrideA / - IteratorA::ThreadMap::kElementsPerAccess, - reinterpret_cast(src_ptr_imag), - pred_guard); - - ++iterator_A_real; - ++iterator_A_imag; - } - - ++this->smem_iterator_A_; - } - - iterator_B_real.set_iteration_index(0); - iterator_B_imag.set_iteration_index(0); - - this->smem_iterator_B_.set_iteration_index(0); - - // Load for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLoadIterationsB; ++j) { - - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - bool pred_guard = iterator_B_real.valid(); - - auto src_ptr_real = iterator_B_real.get(); - auto src_ptr_imag = iterator_B_imag.get(); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, src_ptr_real, pred_guard); - - cutlass::arch::cp_async_zfill( - dst_ptr + v + - Base::SharedStorage::kImaginaryStrideB / - IteratorB::ThreadMap::kElementsPerAccess, - reinterpret_cast(src_ptr_imag), - pred_guard); - - ++iterator_B_real; - ++iterator_B_imag; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A_real.add_tile_offset({0, 1}); - iterator_A_imag.add_tile_offset({0, 1}); - - iterator_B_real.add_tile_offset({1, 0}); - iterator_B_imag.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Inserts a memory fence between stages of cp.async instructions - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Blocks until all but kStages-2 cp.async stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - - WarpFragmentA warp_frag_real_A[2]; - WarpFragmentA warp_frag_imag_A[2]; - - WarpFragmentB warp_frag_real_B[2]; - WarpFragmentB warp_frag_imag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); - this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); - - this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); - this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A_real.clear_mask(gemm_k_iterations == 0); - iterator_A_imag.clear_mask(gemm_k_iterations == 0); - iterator_B_real.clear_mask(gemm_k_iterations == 0); - iterator_B_imag.clear_mask(gemm_k_iterations == 0); - - // Start issuing the first group of the next stage outside of the mainloop - copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag); - - Operator warp_mma; - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); - - this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - // Issue global->shared copies for the next stage - int group_start_iteration_A, group_start_iteration_B; - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - group_start_iteration_A = 0; - group_start_iteration_B = 0; - } - else { - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - } - - copy_tiles_and_advance( - iterator_A_real, - iterator_A_imag, - iterator_B_real, - iterator_B_imag, - group_start_iteration_A, - group_start_iteration_B); - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - // Inserts a memory fence between stages of cp.async instructions - cutlass::arch::cp_async_fence(); - - // Blocks until all but kStages-2 cp.async stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A_real.add_tile_offset({0, 1}); - iterator_A_imag.add_tile_offset({0, 1}); - - iterator_B_real.add_tile_offset({1, 0}); - iterator_B_imag.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A_real.clear_mask(gemm_k_iterations == 0); - iterator_A_imag.clear_mask(gemm_k_iterations == 0); - iterator_B_real.clear_mask(gemm_k_iterations == 0); - iterator_B_imag.clear_mask(gemm_k_iterations == 0); - } - - warp_mma_planar_complex( - warp_mma, - accum, - warp_frag_real_A[warp_mma_k % 2], - warp_frag_imag_A[warp_mma_k % 2], - warp_frag_real_B[warp_mma_k % 2], - warp_frag_imag_B[warp_mma_k % 2]); - } - - } - - - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h deleted file mode 100644 index 44585961f48a2c0de332ba9577a626f89a6da4f6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h +++ /dev/null @@ -1,424 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/aligned_buffer.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Transformation applied to A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Transformation applied to B - ComplexTransform TransformB = ComplexTransform::kNone -> -class MmaPlanarComplexPipelined : - public MmaPlanarComplexBase { -public: - ///< Base class - using Base = MmaPlanarComplexBase; - - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - - ///< Data type of accumulator matrix - using ElementC = ElementC_; - - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - - ///< Policy describing tuning details - using Policy = Policy_; - - using ArchTag = typename Policy::Operator::ArchTag; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - /// Transformation applied to A - static ComplexTransform const kTransformA = TransformA; - - /// Transformation applied to B - static ComplexTransform const kTransformB = TransformB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = ArrayPlanarComplex< - typename Policy::Operator::FragmentC::Element, - Policy::Operator::FragmentC::kElements - >; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - private: - - using FragmentA = typename IteratorA::Fragment; - using FragmentB = typename IteratorB::Fragment; - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaPlanarComplexPipelined( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - -private: - - CUTLASS_DEVICE - void warp_mma_planar_complex( - Operator & warp_mma, - FragmentC &accum, - WarpFragmentA const & real_A, - WarpFragmentA const & imag_A, - WarpFragmentB const & real_B, - WarpFragmentB const & imag_B) { - - cutlass::negate> neg_op_B; - - WarpFragmentB neg_real_B = neg_op_B(real_B); - WarpFragmentB neg_imag_B = neg_op_B(imag_B); - - warp_mma(accum.real, real_A, real_B, accum.real); - - if (kTransformB == ComplexTransform::kNone) { - warp_mma(accum.imag, real_A, imag_B, accum.imag); - } - else { - warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); - } - - if (kTransformA == ComplexTransform::kNone) { - warp_mma(accum.imag, imag_A, real_B, accum.imag); - } - else { - warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); - } - - if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { - warp_mma(accum.real, imag_A, imag_B, accum.real); - } - else { - warp_mma(accum.real, imag_A, neg_imag_B, accum.real); - } - } - -public: - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A_real, - ///< iterator over A operand in global memory - IteratorA iterator_A_imag, - ///< iterator over B operand in global memory - IteratorB iterator_B_real, - ///< iterator over B operand in global memory - IteratorB iterator_B_imag, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A_real; - FragmentA tb_frag_A_imag; - - FragmentB tb_frag_B_real; - FragmentB tb_frag_B_imag; - - tb_frag_A_real.clear(); - tb_frag_A_imag.clear(); - - tb_frag_B_real.clear(); - tb_frag_B_imag.clear(); - - // The last kblock is loaded in the prolog - iterator_A_real.load(tb_frag_A_real); - iterator_A_imag.load(tb_frag_A_imag); - - iterator_B_real.load(tb_frag_B_real); - iterator_B_imag.load(tb_frag_B_imag); - - ++iterator_A_real; - ++iterator_A_imag; - - ++iterator_B_real; - ++iterator_B_imag; - - this->smem_iterator_A_.store(tb_frag_A_real); - this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); - - this->smem_iterator_B_.store(tb_frag_B_real); - this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_real_A[2]; - WarpFragmentA warp_frag_imag_A[2]; - - WarpFragmentB warp_frag_real_B[2]; - WarpFragmentB warp_frag_imag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); - this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); - - this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); - this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); - - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A_real.clear_mask(gemm_k_iterations <= 1); - iterator_A_imag.clear_mask(gemm_k_iterations <= 1); - - iterator_B_real.clear_mask(gemm_k_iterations <= 1); - iterator_B_imag.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tightest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(tb_frag_A_real); - this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); - - this->smem_iterator_B_.store(tb_frag_B_real); - this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); - - __syncthreads(); - - ++this->smem_iterator_B_; - ++this->smem_iterator_A_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, - 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); - - this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k == 0) { - - iterator_A_real.load(tb_frag_A_real); - iterator_A_imag.load(tb_frag_A_imag); - - iterator_B_real.load(tb_frag_B_real); - iterator_B_imag.load(tb_frag_B_imag); - - ++iterator_A_real; - ++iterator_A_imag; - ++iterator_B_real; - ++iterator_B_imag; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A_real.clear_mask(gemm_k_iterations <= 2); - iterator_A_imag.clear_mask(gemm_k_iterations <= 2); - iterator_B_real.clear_mask(gemm_k_iterations <= 2); - iterator_B_imag.clear_mask(gemm_k_iterations <= 2); - } - - warp_mma_planar_complex( - warp_mma, - accum, - warp_frag_real_A[warp_mma_k % 2], - warp_frag_imag_A[warp_mma_k % 2], - warp_frag_real_B[warp_mma_k % 2], - warp_frag_imag_B[warp_mma_k % 2]); - } - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h deleted file mode 100644 index 3caba9f3110e31157692fc3dccbfd2842b305996..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h +++ /dev/null @@ -1,265 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/aligned_buffer.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" - - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Used for partial specialization - typename Enable = bool -> -class MmaSingleStage : public MmaBase { -public: - - ///< Base class - using Base = MmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - using ArchTag = arch::Sm70; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for MmaSingleStage is 1 (single stage mma pipeline) - static_assert((Base::kStages==1), "MmaSingleStage requires kStages set to value 1"); -private: - - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - -protected: - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaSingleStage( - typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum) { ///< source accumulator tile - - // - // Prologue - // - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A; - WarpFragmentB warp_frag_B; - - Operator warp_mma; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - this->smem_iterator_A_.store(tb_frag_A); - this->smem_iterator_B_.store(tb_frag_B); - - __syncthreads(); - - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_frag_A); - this->warp_tile_iterator_B_.load(warp_frag_B); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - warp_mma(accum, warp_frag_A, warp_frag_B, accum); - } - - // Add negative offsets to return smem load iterators to the 'start' of the shared memory - this->warp_tile_iterator_A_.add_tile_offset({0, -Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); - - __syncthreads(); - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h deleted file mode 100644 index 5174be4babd78b5698ad7e6e4ac28134175f4a0b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h +++ /dev/null @@ -1,756 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. - - It loads two loop invariant vectors, norm and sum, in the prologue and - stores them in the register file. We will call elementwise operation to - apply norm and sum between ldmatrix and warp mma. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -#include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/gemm/warp/softmax_scale_bias_transform.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class MmaMainloopFusionBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = cutlass::gemm::GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaMainloopFusionBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -}; - - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterates over vectors of var and mean vector in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorNormSum_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Whether problem has been transformed. This determines to which operand - /// the softmax is applied. - bool InternalTranspose, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class MmaSoftmaxMainloopFusionMultistage : - public MmaMainloopFusionBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of the var and mean vectors in global memory - using IteratorNormSum = IteratorNormSum_; - ///< Policy describing tuning details - using Policy = Policy_; - - ///< Base class - using Base = MmaMainloopFusionBase; - - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - using WarpLoadedFragmentNormSum = typename IteratorNormSum::Fragment; - - static bool const kInternalTranspose = InternalTranspose; - - using SoftmaxFragment = typename platform::conditional::type; - - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - int warp_idx_m_; - - int warp_idx_n_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaSoftmaxMainloopFusionMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; - warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, - IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over B operand in global memory - IteratorNormSum iterator_norm_sum, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - // Issue several complete stages - - WarpLoadedFragmentNormSum warp_loaded_frag_norm_sum; - iterator_norm_sum.add_tile_offset({0, warp_idx_m_}); - iterator_norm_sum.load(warp_loaded_frag_norm_sum); - - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - cutlass::gemm::warp::SoftmaxScaleBiasTransform< - SoftmaxFragment, WarpLoadedFragmentNormSum> elementwise_transform; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - // Start issuing the first group of the next stage outside of the mainloop - copy_tiles_and_advance(iterator_A, iterator_B); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - if (kInternalTranspose) { - elementwise_transform(warp_transformed_frag_B[0], - warp_loaded_frag_norm_sum); - } else { - elementwise_transform(warp_transformed_frag_A[0], - warp_loaded_frag_norm_sum); - } - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) { - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - if (kInternalTranspose) { - elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_norm_sum); - } else { - elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_loaded_frag_norm_sum); - } - } - - // Issue global->shared copies for the next stage - int group_start_iteration_A, group_start_iteration_B; - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - group_start_iteration_A = 0; - group_start_iteration_B = 0; - } else { - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - } - - copy_tiles_and_advance(iterator_A, iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - if (kInternalTranspose) { - elementwise_transform(warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_norm_sum); - } else { - elementwise_transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_norm_sum); - } - } - } - - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h deleted file mode 100644 index 9e94b0ffbf54678d8de3b51ec75bfa2c7966d54b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h +++ /dev/null @@ -1,273 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Policy object describing MmaTensorOp -template < - /// Warp-level GEMM operator (concept: gemm::warp::Mma) - typename Operator_, - /// Padding used for A operand in shared memory (concept: MatrixShape) - typename SmemPaddingA_, - /// Padding used for B operand in shared memory (concept: MatrixShape) - typename SmemPaddingB_, - /// Padding used for E operand in shared memory (concept: MatrixShape) - typename SmemPaddingE_, - /// Number of partitions of K dimension of GEMM - int PartitionsK = 1> -struct SparseMmaPolicy { - /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) - using Operator = Operator_; - - /// Padding used for A operand in shared memory - using SmemPaddingA = SmemPaddingA_; - - /// Padding used for B operand in shared memory - using SmemPaddingB = SmemPaddingB_; - - /// Padding used for B operand in shared memory - using SmemPaddingE = SmemPaddingE_; - - /// Number of partitions of K dimension - static int const kPartitionsK = PartitionsK; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class SparseMmaBase { - public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = - (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - static_assert(kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - static_assert((kWarpGemmIterations % 2) == 0, - "Inner loop iteration must be an even number."); - - /// Number of stages - static int const kStages = Stages; - - static int const kSparse = Operator::kSparse; - - static int const kElementsPerElementE = Operator::kElementsPerElementE; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - /// Tensor reference to the E operand - using TensorRefE = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - /// Shape of the E matrix operand in shared memory - using ShapeE = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer for E operand - AlignedBuffer operand_E; - - public: - - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a layout object for the E matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutE LayoutE() { - return Operator::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() { - return TensorRefB{operand_B.data(), LayoutB()}; - } - - /// Returns a TensorRef to the E operand - CUTLASS_HOST_DEVICE - TensorRefE operand_E_ref() { - return TensorRefE{operand_E.data(), LayoutE()}; - } - }; - - protected: - - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - - /// Iterator to load a warp-scoped tile of E operand from shared memory - typename Operator::IteratorE warp_tile_iterator_E_; - - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - SparseMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), - warp_tile_iterator_E_(shared_storage.operand_E_ref(), lane_idx) { - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h deleted file mode 100644 index 8bc23c3fb77596ed3529dae8ec543c80b6060526..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h +++ /dev/null @@ -1,668 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/threadblock/mma_sparse_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Iterates over tiles of E operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorE_, - /// Iterates over tiles of E operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorE_, - /// Cache operation for operand E - cutlass::arch::CacheOperation::Kind CacheOpE, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class SparseMmaMultistage : - public SparseMmaBase { -public: - ///< Base class - using Base = SparseMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of E operand in global memory - using IteratorE = IteratorE_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorE = SmemIteratorE_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - static cutlass::arch::CacheOperation::Kind const kCacheOpE = CacheOpE; - - static int const kSparse = Policy::Operator::kSparse; - static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; - static int const kMaxID2 = Policy::Operator::kMaxID2; - static int const kElementsPerElementE = - Policy::Operator::kElementsPerElementE; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// ElementE - using ElementE = typename IteratorE::Element; - - /// LayoutE - using LayoutE = typename IteratorE::Layout; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of async copies to load one stage of operand A - static int const TBLoadIterationsA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of async copies to load one stage of operand B - static int const TBLoadIterationsB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of async copies to load one stage of operand E - static int const TBLoadIterationsE = - IteratorE::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of async copies to load one group of operand A - static int const kAccessesPerGroupA = - (TBLoadIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of async copies to load one group of operand B - static int const kAccessesPerGroupB = - (TBLoadIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of async copies to load one group of operand E - static int const kAccessesPerGroupE = - (TBLoadIterationsE + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// E operand is tiny. For the most of time, not all the warps are needed - /// to load it from the global memory. - static int const kValidWarps = IteratorE::ThreadMap::kThreads / 32; - - /// B operand is twice as big as A which brings very high register pressure. - /// We have to sacrifice the double buffer when the warp tile size is big. - static int const kBBufferSize = - ((sizeof(typename Operator::ElementC) == 4) && - ((platform::is_same::value && - platform::is_same::value)) && - (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) - ? 1 - : 2; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - using WarpFragmentE = typename Operator::FragmentE; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of E operand to shared memory - SmemIteratorE smem_iterator_E_; - - /// Warp id - bool is_warp_valid_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - SparseMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_E_(shared_storage.operand_E_ref(), thread_idx) - { - is_warp_valid_ = warp_idx < Detail::kValidWarps; - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - this->warp_tile_iterator_E_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, - IteratorE &iterator_E, int group_start_A = 0, - int group_start_B = 0, int group_start_E = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // async copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::TBLoadIterationsA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // async copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::TBLoadIterationsB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - - iterator_E.set_iteration_index(group_start_E); - this->smem_iterator_E_.set_iteration_index(group_start_E); - - // async copy for operand E - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupE; ++j) { - if (group_start_E + j < Detail::TBLoadIterationsE) { - typename IteratorE::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_E_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorE::ThreadMap::kElementsPerAccess / 8; - - auto gmem_ptr = iterator_E.get(); - - cutlass::arch::cp_async( - dst_ptr, gmem_ptr, iterator_E.valid() && is_warp_valid_); - - ++iterator_E; - ++this->smem_iterator_E_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over E operand in global memory - IteratorE iterator_E, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_E.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // async copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLoadIterationsA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // async copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLoadIterationsB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - iterator_E.set_iteration_index(0); - this->smem_iterator_E_.set_iteration_index(0); - - // async copy for operand E - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLoadIterationsE; ++j) { - typename IteratorE::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_E_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorE::ThreadMap::kElementsPerAccess / 8; - if (is_warp_valid_) - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_E.get(), iterator_E.valid()); - - ++iterator_E; - - ++this->smem_iterator_E_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - iterator_E.add_tile_offset({0, 1}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - this->smem_iterator_E_.add_tile_offset({0, 1}); - - // cp.async.commit_group - completes a stage - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[Detail::kBBufferSize]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[Detail::kBBufferSize]; - WarpFragmentE warp_frag_E[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_E_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - this->warp_tile_iterator_E_.load(warp_frag_E[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - ++this->warp_tile_iterator_E_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_E.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_E_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_E_.load(warp_frag_E[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_E_; - - if (Detail::kBBufferSize == 2) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load( - warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % Detail::kBBufferSize]); - - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], accum, - warp_frag_E[warp_mma_k % 2] - ); - - if (Detail::kBBufferSize == 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - ++this->warp_tile_iterator_B_; - - } - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - group_start_iteration_E = warp_mma_k * Detail::kAccessesPerGroupE; - - copy_tiles_and_advance( - iterator_A, iterator_B, iterator_E, group_start_iteration_A, - group_start_iteration_B, group_start_iteration_E); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - group_start_iteration_E = - (warp_mma_k + 1) * Detail::kAccessesPerGroupE; - - copy_tiles_and_advance( - iterator_A, iterator_B, iterator_E, group_start_iteration_A, - group_start_iteration_B, group_start_iteration_E); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - iterator_E.add_tile_offset({0, 1}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - this->smem_iterator_E_.add_tile_offset({0, 1}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_E_.add_tile_offset({0, -Base::kStages}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - this->warp_tile_iterator_E_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_E.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]); - } - - } - - - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h deleted file mode 100644 index 2fd49a5bc462d81040abd463098a357f5eab2465..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h +++ /dev/null @@ -1,545 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/threadblock/mma_base.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class MmaWithReductionMultistage : - public MmaBase { -public: - ///< Base class - using Base = MmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - using FragmentReduction = typename Operator::FragmentReduction; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - static int const kReduceKForA = Operator::kReduceKForA; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - - private: - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - - private: - - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - MmaWithReductionMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, - int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const &src_accum, - FragmentReduction &gemm_k_reduction_accum) { - - // - // Prologue - // - // Issue several complete stages - - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum, - gemm_k_reduction_accum - ); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - } - - } - - // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h deleted file mode 100644 index 9495d785536910355a5d0f9a3cd91dc7b5895747..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ /dev/null @@ -1,459 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implements several possible threadblock-swizzling functions mapping blockIdx to - GEMM problems. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/platform/platform.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/gemm/threadblock/index_remat.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for GEMMs -template -struct GemmIdentityThreadblockSwizzle { - - CUTLASS_HOST_DEVICE - GemmIdentityThreadblockSwizzle() { } - - /// Returns the shape of the problem in units of logical tiles - /// *Gemm* problem size: gemm(M, N, K) - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - GemmCoord problem_size, - GemmCoord tile_size, - int split_k_slices) { - - return GemmCoord( - (problem_size.m() + tile_size.m() - 1) / tile_size.m(), - (problem_size.n() + tile_size.n() - 1) / tile_size.n(), - split_k_slices); - } - - /// Returns the shape of the problem in units of logical tiles - /// *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem_size, - GemmCoord tile_size, - int split_k_slices) { - - gemm::GemmCoord implicit_gemm_problem_size = - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); - - return get_tiled_shape( - implicit_gemm_problem_size, tile_size, split_k_slices); - } - - /// Returns the shape of the problem in units of logical tiles - /// *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC) - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv3dProblemSize const &problem_size, - GemmCoord tile_size, - int split_k_slices) { - - gemm::GemmCoord implicit_gemm_problem_size = - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); - - return get_tiled_shape( - implicit_gemm_problem_size, tile_size, split_k_slices); - } - - /// Computes CUDA grid dimensions given a size in units of logical tiles - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(GemmCoord tiled_shape) { - int tile = 1 << get_log_tile(tiled_shape); - return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); - } - - /// Calculates optimal swizzle width - CUTLASS_HOST_DEVICE - static int get_log_tile(GemmCoord tiled_shape) { - auto n = tiled_shape.n(); - // Thresholds picked so that it doesn't cause too many no-op CTAs - if (N >= 8 && n >= 6) - return 3; - else if (N >= 4 && n >= 3) - return 2; - else if (N >= 2 && n >= 2) - return 1; - else - return 0; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(int log_tile) { - int block_idx_x = RematerializeBlockIdxX(); - int block_idx_y = RematerializeBlockIdxY(); - int block_idx_z = RematerializeBlockIdxZ(); - - return GemmCoord{(block_idx_x >> log_tile), // - (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), - block_idx_z}; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(GemmCoord tiled_shape) { - - int const kTile = N; - int block_idx_x = RematerializeBlockIdxX(); - int block_idx_y = RematerializeBlockIdxY(); - - if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) - return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; - - return GemmCoord{ - (block_idx_x / kTile), - (block_idx_y * kTile) + (block_idx_x % kTile), - RematerializeBlockIdxZ() - }; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for GEMMs -struct GemmHorizontalThreadblockSwizzle { - - CUTLASS_HOST_DEVICE - GemmHorizontalThreadblockSwizzle() { } - - /// Returns the shape of the problem in units of logical tiles - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - GemmCoord problem_size, - GemmCoord tile_size, - int split_k_slices) { - - return GemmCoord( - (problem_size.m() + tile_size.m() - 1) / tile_size.m(), - (problem_size.n() + tile_size.n() - 1) / tile_size.n(), - split_k_slices); - } - - /// Computes CUDA grid dimensions given a size in units of logical tiles - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(GemmCoord tiled_shape) { - return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); - } - - /// Calculates optimal swizzle width - CUTLASS_HOST_DEVICE - static int get_log_tile(GemmCoord tiled_shape) { - return 0; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(GemmCoord tiled_shape) { - return GemmCoord{ - RematerializeBlockIdxY(), - RematerializeBlockIdxX(), - RematerializeBlockIdxZ() - }; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for batched GEMMs -struct GemmBatchedIdentityThreadblockSwizzle { - - /// Returns the shape of the problem in units of logical tiles - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - GemmCoord problem_size, - GemmCoord tile_size, - int batch_count) { - - return GemmCoord( - (problem_size.m() + tile_size.m() - 1) / tile_size.m(), - (problem_size.n() + tile_size.n() - 1) / tile_size.n(), - batch_count % (1 << 16)); - } - - /// Computes CUDA grid dimensions given a size in units of logical tiles - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(GemmCoord tiled_shape) { - return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); - } - - /// Calculates optimal swizzle width - CUTLASS_HOST_DEVICE - static int get_log_tile(GemmCoord tiled_shape) { - return 0; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(GemmCoord tiled_shape) { - return GemmCoord{ - RematerializeBlockIdxX(), - RematerializeBlockIdxY(), - RematerializeBlockIdxZ() - }; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(int log_tile) { - int block_idx_x = RematerializeBlockIdxX(); - int block_idx_y = RematerializeBlockIdxY(); - int block_idx_z = RematerializeBlockIdxZ(); - - return GemmCoord{(block_idx_x >> log_tile), // - (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), - block_idx_z}; - } - - /// Gets the batch index - CUTLASS_DEVICE - static int get_batch_idx() { - return RematerializeBlockIdxZ(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for split-K GEMMs -template -struct GemmSplitKIdentityThreadblockSwizzle { - - int const kTile = N; - - /// Returns the shape of the problem in units of logical tiles - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - GemmCoord problem_size, - GemmCoord tile_size, - int partitions) { - - return GemmCoord( - (problem_size.m() + tile_size.m() - 1) / tile_size.m(), - (problem_size.n() + tile_size.n() - 1) / tile_size.n(), - partitions); - } - - /// Calculates optimal swizzle width - CUTLASS_HOST_DEVICE - static int get_log_tile(GemmCoord tiled_shape) { - auto n = tiled_shape.n(); - // Thresholds picked so that it doesn't cause too many no-op CTAs - if (N >= 8 && n >= 6) - return 3; - else if (N >= 4 && n >= 3) - return 2; - else if (N >= 2 && n >= 2) - return 1; - else - return 0; - } - - /// Computes CUDA grid dimensions given a size in units of logical tiles - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(GemmCoord tiled_shape) { - int tile = 1 << get_log_tile(tiled_shape); - return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(int log_tile) { - int block_idx_x = RematerializeBlockIdxX(); - int block_idx_y = RematerializeBlockIdxY(); - int block_idx_z = RematerializeBlockIdxZ(); - - return GemmCoord{(block_idx_x >> log_tile), // - (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), - block_idx_z}; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(GemmCoord tiled_shape) { - - int const kTile = N; - int block_idx_x = RematerializeBlockIdxX(); - int block_idx_y = RematerializeBlockIdxY(); - - if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) - return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; - - return GemmCoord{ - (block_idx_x / kTile), - (block_idx_y * kTile) + (block_idx_x % kTile), - RematerializeBlockIdxZ() - }; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for split-K GEMMs -struct GemmSplitKHorizontalThreadblockSwizzle { - - /// Returns the shape of the problem in units of logical tiles - CUTLASS_HOST_DEVICE - static GemmCoord get_tiled_shape( - GemmCoord problem_size, - GemmCoord tile_size, - int partitions) { - - return GemmCoord( - (problem_size.m() + tile_size.m() - 1) / tile_size.m(), - (problem_size.n() + tile_size.n() - 1) / tile_size.n(), - partitions); - } - - /// Computes CUDA grid dimensions given a size in units of logical tiles - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(GemmCoord tiled_shape) { - return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); - } - - /// Calculates optimal swizzle width - CUTLASS_HOST_DEVICE - static int get_log_tile(GemmCoord tiled_shape) { - return 0; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(int log_tile) { - return GemmCoord{ - RematerializeBlockIdxY(), - RematerializeBlockIdxX(), - RematerializeBlockIdxZ() - }; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static GemmCoord get_tile_offset(GemmCoord tiled_shape) { - return GemmCoord{ - RematerializeBlockIdxY(), - RematerializeBlockIdxX(), - RematerializeBlockIdxZ() - }; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock swizzling function for batched GEMVs -struct GemvBatchedStridedThreadblockDefaultSwizzle { - - /// Returns the shape of the problem in units of logical tiles - CUTLASS_HOST_DEVICE - static BatchedGemmCoord get_tiled_shape( - BatchedGemmCoord problem_size, - BatchedGemmCoord tile_size) { - - return BatchedGemmCoord( - 1, // M is always 1 - (problem_size.n() + tile_size.n() - 1) / tile_size.n(), - (problem_size.k() + tile_size.k() - 1) / tile_size.k(), - (problem_size.batch() + tile_size.batch() - 1) / tile_size.batch()); - } - - /// Computes CUDA grid dimensions given a size in units of logical tiles - CUTLASS_HOST_DEVICE - static dim3 get_grid_shape(BatchedGemmCoord tiled_shape) { - return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k()); - } - - /// Calculates optimal swizzle width - CUTLASS_HOST_DEVICE - static int get_log_tile(GemmCoord tiled_shape) { - return 0; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static BatchedGemmCoord get_tile_offset(int log_tile) { - return BatchedGemmCoord{ - 0, // M is always 1 - RematerializeBlockIdxX(), - RematerializeBlockIdxZ(), - RematerializeBlockIdxY(), - }; - } - - /// Obtains the threadblock offset (in units of threadblock-scoped tiles) - CUTLASS_DEVICE - static BatchedGemmCoord get_tile_offset() { - return BatchedGemmCoord{ - 0, // M is always 1 - RematerializeBlockIdxX(), - RematerializeBlockIdxZ(), - RematerializeBlockIdxY(), - }; - } - - /// Gets the batch tile index - CUTLASS_DEVICE - static int get_batch_tile_idx() { - return RematerializeBlockIdxY(); - } - - /// Gets the absolute batch index - CUTLASS_DEVICE - static int get_batch_idx() { - return RematerializeBlockDimY()*RematerializeBlockIdxY() + RematerializeThreadIdxY(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h deleted file mode 100644 index da54eee5a7618c61fc0b9736418ae05ce0466bce..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ /dev/null @@ -1,801 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implements streamk threadblock mapping blockIdx to GEMM problems. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/platform/platform.h" -#include "cutlass/gemm/gemm_enumerated_types.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/gemm/threadblock/index_remat.h" - -#if !defined(__CUDACC_RTC__) -#include -#include "cutlass/core_io.h" -#include "cutlass/trace.h" -#endif - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Threadblock mapping control for GEMMs -struct ThreadblockSwizzleStreamK { - - /// Advertise StreamkFeature - using StreamkFeature = void; - - - /// Kernel traits - template - struct KernelTraits {}; - - - /// Reduction strategy - enum ReductionStrategy - { - kNone, // Data-parallel strategy (no seams, fixup, etc.) - - kAtomic, // Non-deterministic reduction of SK-block partials using atomic aggregation in L2 - - kMixed, // Deterministic reduction of SK-block partials employing either: - // (a) A separate wave of reduction thread blocks" (for scenarios with lots of - // SK-blocks per SK-tile) - // (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few - // SK-blocks per SK-tile) - }; - - static ReductionStrategy const kReductionStrategy = kMixed; - - - // - // Heuristics - // - - /// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel) - static float constexpr kDpEfficiencyThreshold = 0.92f; - - /// Minimum number of MAC-iterations per streamk block - static int const kMinItersPerSkBlock = 2; - - /// Height in CTAs of a grid rasterization cohort - static int const kCohortCtasM = 8; - - /// Width in CTAs of a grid rasterization cohort - static int const kCohortCtasN = 4; - - /// Number of CTAs per cohort - static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM; - - /// Cost-equivalent number of SM-iterations for fixup I/O - static int const kFixupStartupIterEquiv = 10; - static int const kFixupPeerIterEquiv = 3; - - - // - // Member state - // - - - /// The 3D value-extents of the GEMM computation volume (m,n,k) - GemmCoord problem_size; - - /// Div/mod accelerators - FastDivmod div_mod_tiled_shape_m; - FastDivmod div_mod_tiled_shape_n; - FastDivmod div_mod_tiled_cohort_shape_n; - FastDivmod div_mod_iters_per_tile; - - /// Whether to perform cohort CTA rasterization - bool cohort_raster; - - // Whether to pad and remap block indices - bool remap_block_indices; - - /// CTA occupancy per SM - int sm_occupancy; - - /// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size) - int avail_sms; - - int dp_blocks; /// Number of data-parallel thread blocks in the grid - int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce - - /// Number of reduction blocks in the grid - int reduction_blocks; - - int sk_waves; - int sk_tiles; - int sk_big_blocks_per_region; - int sk_iters_per_region; - - /// Div/mod accelerators - FastDivmod div_mod_sk_iters_per_normal_block; - FastDivmod div_mod_sk_iters_per_big_block; - FastDivmod div_mod_sk_iters_per_region; - FastDivmod div_mod_sk_regions; //!! used in block map - FastDivmod div_mod_sk_blocks_per_region; //!! used in block map - - /// The batch count - int batch_count; - - - // - // Host+device interface - // - - /// Constructor - ThreadblockSwizzleStreamK() = default; - - /// Returns the GEMM volume in thread block tiles - CUTLASS_HOST_DEVICE - GemmCoord tiled_shape() const - { - return GemmCoord( - static_cast(div_mod_tiled_shape_m), - static_cast(div_mod_tiled_shape_n), - batch_count); - } - - /// Number of iterations per output tile - CUTLASS_HOST_DEVICE - int iters_per_tile() const - { - return static_cast(div_mod_iters_per_tile); - } - - /// Number of iterations for normal SK-blocks - CUTLASS_HOST_DEVICE - int sk_iters_per_normal_block() const - { - return static_cast(div_mod_sk_iters_per_normal_block); - } - - /// Number of SK regions - CUTLASS_HOST_DEVICE - int sk_regions() const - { - return static_cast(div_mod_sk_regions); - } - - /// Number of SK blocks per region (splitting factor) - CUTLASS_HOST_DEVICE - int sk_blocks_per_region() const - { - return static_cast(div_mod_sk_blocks_per_region); - } - - - // - // Host-side interface - // - - /// Debug print - void Print() - { -#ifndef __CUDA_ARCH__ - auto tiles = tiled_shape().mn().product(); - std::cout << - "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" << - ", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" << - ", tiles: " << tiles << - ", dp_tiles: " << tiles - sk_tiles << - ", sk_tiles: " << sk_tiles << - ", iters_per_tile: " << iters_per_tile() << - ", reduction_blocks: " << reduction_blocks << - ", dp_blocks: " << dp_blocks << - ", dp_waves: " << dp_blocks / avail_sms << - ", dp_first_wave_tiles: " << dp_first_wave_tiles << - ", sk_blocks_per_region: " << sk_blocks_per_region() << - ", sk_regions: " << sk_regions() << - ", sk_waves: " << sk_waves << - ", sk_iters_per_normal_block: " << sk_iters_per_normal_block() << - ", sk_big_blocks_per_region: " << sk_big_blocks_per_region << - ", remap_block_indices: " << remap_block_indices << - ", cohort_raster: " << cohort_raster << - ", sm_occupancy: " << sm_occupancy << - ", avail_sms: " << avail_sms << - ", num_blocks: " << get_num_blocks() << - "\n\n"; -#endif - } - - - // Compute sk_blocks to dispatch for a given number of sk_tiles - static void get_sk_blocks( - int &sk_blocks, /// [out] - int &savings_iters, /// [out] - int sk_tiles, - int iters_per_tile, - int avail_sms, - int max_sk_occupancy, - bool allow_partial_wave) - { - savings_iters = INT_MIN; - sk_blocks = 0; - - if (sk_tiles == 0) { - return; - } - - int sk_iters = sk_tiles * iters_per_tile; - - int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms; - int dp_equiv_iters = iters_per_tile * dp_equiv_waves; - - int min_sk_blocks = (allow_partial_wave) ? fast_min(avail_sms, sk_tiles + 1) : avail_sms; - int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock); - - for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks) - { - int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms; - int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks; - int sk_iter_equiv = max_sk_iters_per_block * sk_waves; - - int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1; // add one for alignment skew - - float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv); - - if (trial_sk_blocks % sk_tiles == 0) - { - // aligned - num_peers = (trial_sk_blocks / sk_tiles); - - iter_cost = 0.0f; - } - - float peer_cost = 2.0f * float(num_peers); - - float base_cost = 2.0f * float(sk_waves); - - int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost); - - int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv; - - if (trial_savings_iters >= savings_iters) { - savings_iters = trial_savings_iters; - sk_blocks = trial_sk_blocks; - } - } - } - - - /// Determine the populations of DP and SK blocks to invoke for the given number of output tiles - static void get_blocks( - int &dp_tiles, /// [out] - int &sk_blocks, /// [out] - int output_tiles, - int iters_per_tile, - int avail_sms, - int sm_occupancy) - { - int full_waves = output_tiles / avail_sms; - int full_wave_tiles = full_waves * avail_sms; - int partial_wave_tiles = output_tiles - full_wave_tiles; - - int score = -1; - dp_tiles = output_tiles; - sk_blocks = 0; - - if (partial_wave_tiles == 0) - { - // Perfect quantization - return; - } - - if (full_waves < sm_occupancy) - { - // We're less than full GPU occupancy - - // Form the SK wave from the partial wave to get us up to full GPU occupancy - int max_sk_occupancy = sm_occupancy - full_waves; - - dp_tiles = full_wave_tiles; - - get_sk_blocks( - sk_blocks, - score, - partial_wave_tiles, - iters_per_tile, - avail_sms, - max_sk_occupancy, - true); // we can run with less than a full wave of SK-blocks - - if (score < 0) { - // not profitable - sk_blocks = 0; - dp_tiles = output_tiles; - } - - return; - } - - // We're at (or greater) than GPU occupancy - - if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1)) - { - // If occupancy is more than one CTA per SM, form the SK wave from the partial - // wave to get us to full GPU occupancy - int max_sk_occupancy = 1; - - dp_tiles = full_wave_tiles; - - get_sk_blocks( - sk_blocks, - score, - partial_wave_tiles, - iters_per_tile, - avail_sms, - max_sk_occupancy, - true); // we can run with less than a full wave of SK-blocks - - if (score >= 0) { - return; - } - } - - // Form the SK wave by combining the last full wave and the partial wave - // We're less than full GPU occupancy - dp_tiles = full_wave_tiles - avail_sms; - - int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy); - - get_sk_blocks( - sk_blocks, - score, - partial_wave_tiles + avail_sms, - iters_per_tile, - avail_sms, - max_sk_occupancy, - false); // we cannot run with less than a full wave of SK-blocks - - if (score < 0) { - // not profitable - sk_blocks = 0; - dp_tiles = output_tiles; - } - - } - - /// Constructor: *Gemm* problem size (m, n, k) - ThreadblockSwizzleStreamK( - GemmUniversalMode const mode_, - GemmCoord const problem_size_, - GemmCoord const tile_size_, - int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) - int const sm_occupancy_, - int const device_sms_, - int const avail_sms_, /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) - size_t const element_A_bytes_, - size_t const element_B_bytes_, - size_t const element_C_bytes_, - int const epilogue_acc_fragments_) - : - problem_size(problem_size_), - batch_count((mode_ == GemmUniversalMode::kBatched || mode_ == GemmUniversalMode::kArray) ? batch_split_ : 1), - reduction_blocks(0), - dp_blocks(0), - dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks - sk_tiles(0), - sk_big_blocks_per_region(0), - sk_iters_per_region(0), - sk_waves(0), - sm_occupancy(sm_occupancy_), - remap_block_indices(false), - avail_sms(fast_max(1, avail_sms_)), - cohort_raster(false) - { - int gpu_occupancy = device_sms_ * sm_occupancy; - int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k(); - int sk_iters_per_normal_block = 0; - - int sk_regions = 1; // Default: a single region of iteration space (across all SK tiles) - int sk_blocks_per_region = 0; - - GemmCoord tiled_shape( - (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(), - (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(), - batch_count); - - size_t problem_bytes = - (element_C_bytes_ * problem_size.m() * problem_size.n()) + - (element_A_bytes_ * problem_size.m() * problem_size.k()) + - (element_B_bytes_ * problem_size.k() * problem_size.n()); - - size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2; - - [[maybe_unused]] float flops_per_byte = float(problem_flops) / float(problem_bytes); - - int output_tiles = tiled_shape.m() * tiled_shape.n(); - int waves = (output_tiles + avail_sms - 1) / avail_sms; - [[maybe_unused]] float dp_efficiency = float(output_tiles) / float(waves * avail_sms); - - // - // Determine dispatch composition of DP-tiles and SK-blocks - // - - // Start with a DP-only configuration - int dp_tiles = output_tiles; // Number of data-parallel tiles - int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles - - // Only kGemm mode allows for SK load balancing - if (mode_ == GemmUniversalMode::kGemm) - { - int split_factor = batch_split_; - if (split_factor > 1) - { - // Split-K override - dp_tiles = 0; - sk_blocks = output_tiles * split_factor; - } - else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled - (avail_sms > 1)) // Plurality of SMs to load balance across - { - // Use heuristics - get_blocks( - dp_tiles, /// [out] - sk_blocks, /// [out] - output_tiles, - iters_per_tile, - avail_sms, - sm_occupancy); - } - } - - sk_tiles = output_tiles - dp_tiles; - - - // Compute SK block iteration details - if (sk_blocks > 0) - { - sk_waves = (sk_blocks + avail_sms - 1) / avail_sms; - - int sk_iters = sk_tiles * iters_per_tile; - sk_blocks = fast_min(sk_blocks, sk_iters); - - sk_iters_per_normal_block = sk_iters / sk_blocks; - int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks); - int sk_big_blocks = extra_sk_iters; - - if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0)) - { - // Split-K decomposition - sk_regions = sk_tiles; - } - - sk_blocks_per_region = sk_blocks / sk_regions; - sk_big_blocks_per_region = sk_big_blocks / sk_regions; - sk_iters_per_region = sk_iters / sk_regions; - - // Use a separate reduction wave when all of: - // - Non-atomic reduction stratgy - // - The number of SK waves won't fully occupy the GPU (Otherwise we don't have - // a strong-scaling case for more parallel reduction) - // - More than three peers working on an SK tile. (This occurs when the ratio of - // SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks, - // e.g.:[partial-block | block | block | partial-block] ). With three or - // less peers, the two non-finishing SK-blocks are not expected to contend. - if ((kReductionStrategy == kMixed) && - (sk_waves < sm_occupancy) && - (sk_blocks > 2 * sk_tiles)) - { - // Launch a reduction block for every accumulator fragment in each SK-tile - reduction_blocks = sk_tiles * epilogue_acc_fragments_; - - } - - // When we have a multi-occupancy kernel and at least two waves of active blocks (where - // at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2) - // remap the block indices so that we can reliably spread the SK blocks evenly across the - // device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx(). - remap_block_indices = ( - (sm_occupancy > 1) && - (device_sms_ == avail_sms) && - (get_num_active_blocks() > avail_sms * 2)); - - // Initialize fast div/mod members related to SK - div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block); - div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1); - div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region); - div_mod_sk_regions = FastDivmod(sk_regions); - div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region); - } - - // - // Compute DP blocks - // - - dp_blocks = dp_tiles; - - cutlass::gemm::GemmCoord tiled_cohort_shape( - (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM, - (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN, - tiled_shape.k()); - int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort; - float cohort_efficiency = float(dp_blocks) / float(cohort_blocks); - - // Check if the SK tiles would be in cohorts that are in-bounds - bool sk_in_range = true; - if (sk_tiles > 0) - { - int last_sk_tile = sk_tiles - 1; - int cohort_tile_idx = last_sk_tile / kCtasPerCohort; - int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n(); - int cohort_grid_n = (cohort_grid_m > 0) ? - tiled_cohort_shape.n() - 1 : - cohort_tile_idx % tiled_cohort_shape.n(); - - if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) || - (((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n())) - { - sk_in_range = false; - } - - } - - // Decide if we're going to be doing cohort raster - if (sk_in_range && - (dp_blocks >= gpu_occupancy * 2) && - (cohort_efficiency > 0.85f)) - { - cohort_raster = true; - dp_blocks = cohort_blocks; - } - else if (sk_waves > 0) - { - // Update semi-persistence of first DP wave to ensure full grid wavesets - // (Only applies when there's an SK component and we're not doing blocked cohort rasterization) - int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms; - int full_dp_tile_waves = dp_tiles / avail_sms; - int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy; - - if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves) - { - dp_first_wave_tiles += waveset_excess; - dp_blocks -= (waveset_excess * avail_sms); - } - } - - // Setup fast-div/mod for device-side usage - div_mod_tiled_shape_m = FastDivmod(tiled_shape.m()); - div_mod_tiled_shape_n = FastDivmod(tiled_shape.n()); - div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n()); - div_mod_iters_per_tile = FastDivmod(iters_per_tile); - - } - - /// Number of blocks performing useful work - int get_num_active_blocks() const - { - return (sk_waves * avail_sms) + dp_blocks + reduction_blocks; - } - - /// Obtains number of threadblocks per GEMM - int get_num_blocks() const - { - int active_blocks = get_num_active_blocks(); - if (remap_block_indices) - { - // Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves - return fast_max(active_blocks, avail_sms * 4); - } - - return active_blocks; - } - - - /// Obtains grid extents in CTAs - dim3 get_grid_dims() const - { - return dim3(get_num_blocks(), 1, batch_count); - } - - - // - // Device-side interface - // - - /// Obtains number of threadblocks per GEMM - CUTLASS_DEVICE - int device_num_blocks() const - { - return gridDim.x; - } - - /// Obtains tile index for the given sk iteration - CUTLASS_DEVICE - int get_sk_tile_idx(int iter) const - { - int tile_idx = div_mod_iters_per_tile.div(iter); - return tile_idx; - } - - /// Obtains the batch index - CUTLASS_DEVICE - int get_batch_idx() const - { - return RematerializeBlockIdxZ(); - } - - /// Obtains the calling threadblock's tiled coordinates for the given tile index - CUTLASS_DEVICE - GemmCoord get_tile_offset(int tile_idx) const - { - int m, n; - - // row-major raster - div_mod_tiled_shape_n(m, n, tile_idx); - - if (tiled_shape().m() < tiled_shape().n()) - { - // column-major raster - div_mod_tiled_shape_m(n, m, tile_idx); - } - - if (cohort_raster) - { - // tiled cohort raster - int cohort_tile_idx = tile_idx / kCtasPerCohort; - int cohort_grid_m, cohort_grid_n; - div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx); - - int block_idx_cohort = tile_idx % kCtasPerCohort; - int block_cohort_m = block_idx_cohort / kCohortCtasN; - int block_cohort_n = block_idx_cohort % kCohortCtasN; - - m = (cohort_grid_m * kCohortCtasM) + block_cohort_m; - n = (cohort_grid_n * kCohortCtasN) + block_cohort_n; - } - - return GemmCoord(m, n, get_batch_idx()); - } - - /// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rasterization) - CUTLASS_DEVICE - GemmCoord get_tile_offset_row_major(int tile_idx) const - { - // row-major raster - int m, n; - div_mod_tiled_shape_n(m, n, tile_idx); - return GemmCoord(m, n, get_batch_idx()); - } - - /// Obtains calling threadblock's linear threadblock index - CUTLASS_DEVICE - int get_block_idx() const - { - int block_idx = RematerializeBlockIdxX(); - - // Remap the block indices for the first two waves of thread blocks if - // we have multi-occupancy and the grid constitutes four or more waves - if (remap_block_indices && (block_idx < avail_sms * 2)) - { - int dest_sm = block_idx / 2; - int dest_wave = block_idx % 2; - int remapped_block_idx = dest_sm + (dest_wave * avail_sms); - block_idx = remapped_block_idx; - } - - // Remap block indices to interleave SK regions to limit intra-region waiting - if (block_idx < sk_regions() * sk_blocks_per_region()) - { - int block_in_region; - int region; - div_mod_sk_regions(block_in_region, region, block_idx); - block_idx = (region * sk_blocks_per_region()) + block_in_region; - } - - return block_idx; - } - - - /// Obtains calling linear threadblock index of the first block to work on the given tile - CUTLASS_DEVICE - int get_sk_block_idx(int iter) const - { - int region_idx; - int iter_in_region; - div_mod_sk_iters_per_region(region_idx, iter_in_region, iter); - - int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region; // number of iterations in the region's big blocks - int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal blocks - - int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region); - int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters); - - int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ? - big_block_idx_in_region : - normal_block_idx_in_region; - - int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region; - - return owning_block_idx; - } - - /// Obtains iteration extends for the given SK block index - CUTLASS_DEVICE - void get_iter_extents( - int sk_block_idx, - int &block_iter_begin, - int &block_iter_end) const - { - int region_idx; - int block_idx_in_region; - div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx); - - block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block()); - - // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration - int block_iters = sk_iters_per_normal_block(); - if (block_idx_in_region < sk_big_blocks_per_region) { - // This is a +1 iteration block - block_iter_begin += block_idx_in_region; - block_iters++; - } else { - // This is a regular block - block_iter_begin += sk_big_blocks_per_region; - } - block_iter_end = block_iter_begin + block_iters; - } - - - /// Obtains calling linear threadblock index of the first block to work on the given tile - CUTLASS_DEVICE - int get_first_block_idx(int tile_idx, int block_idx) const - { - if (tile_idx >= sk_tiles) { - // DP tile - return block_idx; - } - - int iter = tile_idx * iters_per_tile(); - return get_sk_block_idx(iter); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h deleted file mode 100644 index 067da30b1901532ffccc69c19906ff6630520f71..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h +++ /dev/null @@ -1,612 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/mma_complex_tensor_op.h" -#include "cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h" -#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Complex transform on A operand - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transform on B operand - ComplexTransform TransformB = ComplexTransform::kNone, - /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_ = arch::OpMultiplyAddComplex> -struct DefaultMmaComplexTensorOp; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex case -// 4 real-valued mma operations -// A = (ar + j ai), B (br +j bi), D = AB -// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Real-valued underlying type of complex-valued A operand - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Real-valued underlying type of complex-valued B operand - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Real-valued underlying type of complex-valued C operand - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - InstructionShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddComplex> { - - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - RealElementA, - cutlass::layout::RowMajor, - RealElementB, - cutlass::layout::ColumnMajor, - RealElementC, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex case using GaussianComplex operation -// 3 real-valued mma operations -// A = (ar + j ai), B = (br +j bi), D = AB -// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -// D = dr + j di = (P1 - P3) + j (P1 + P2) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Real-valued underlying type of complex-valued A operand - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Real-valued underlying type of complex-valued B operand - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Real-valued underlying type of complex-valued C operand - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - InstructionShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddGaussianComplex> { - - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - RealElementA, - cutlass::layout::RowMajor, - RealElementB, - cutlass::layout::ColumnMajor, - RealElementC, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB>; -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization - input and output types are complex*complex -// Use TF32 tensor operation internally -// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 -// A = (ar + j ai), B (br +j bi), D = AB -// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - InstructionShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddComplex> { - - // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - tfloat32_t, - cutlass::layout::RowMajor, - tfloat32_t, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization - input and output types are complex*complex -// Use BF16 tensor operation internally -// 4 real-valued mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 operations on BF16 -// A = (ar + j ai), B (br +j bi), D = AB -// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - InstructionShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddFastBF16> { - - // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 mma instruction - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - bfloat16_t, - cutlass::layout::RowMajor, - bfloat16_t, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization - input and output types are complex*complex -// Use F16 tensor operation internally -// 4 real-valued mma.sync.aligned.m16n8k8.f32.f16.f16.f32 operations on F16 -// A = (ar + j ai), B (br +j bi), D = AB -// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - InstructionShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddFastF16> { - - // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.f16.f16.f32 mma instruction - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - half_t, - cutlass::layout::RowMajor, - half_t, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// 3xTF32 or 4xTF32 (fast and accurate complex operation) -/// Partial specialization - input and output types are complex * complex -// Use 3xTF32 or 4xTF32 tensor operation internally -// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 -// A = (ar + j ai), B (br +j bi), D = AB -// D = dr + j di = 3x[(ar*br - ai*bi) + j (ar*bi + ai*br)] -///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - InstructionShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddComplexFastF32> { - - // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - tfloat32_t, - cutlass::layout::RowMajor, - tfloat32_t, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaComplexTensorOpFastF32< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex case -// 4 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations -// A = (ar + j ai), B (br +j bi), D = AB -// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Real-valued underlying type of complex-valued A operand - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Real-valued underlying type of complex-valued B operand - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Real-valued underlying type of complex-valued C operand - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - GemmShape<16, 8, 4>, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddComplex> { - - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - GemmShape<16, 8, 4>, - 32, - RealElementA, - cutlass::layout::RowMajor, - RealElementB, - cutlass::layout::ColumnMajor, - RealElementC, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB, - true>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for complex*complex case using GaussianComplex operation -// 3 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations -// A = (ar + j ai), B = (br +j bi), D = AB -// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -// D = dr + j di = (P1 - P3) + j (P1 + P2) -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Real-valued underlying type of complex-valued A operand - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Real-valued underlying type of complex-valued B operand - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Real-valued underlying type of complex-valued C operand - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB> -struct DefaultMmaComplexTensorOp< - WarpShape_, - GemmShape<16, 8, 4>, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - TransformA, - TransformB, - arch::OpMultiplyAddGaussianComplex> { - - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - GemmShape<16, 8, 4>, - 32, - RealElementA, - cutlass::layout::RowMajor, - RealElementB, - cutlass::layout::ColumnMajor, - RealElementC, - cutlass::layout::RowMajor, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1> - >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< - WarpShape_, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, - Policy, - TransformA, - TransformB, - true>; -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h deleted file mode 100644 index e2cb3f2249c9beabd0e557c96d7361be2e28a133..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h +++ /dev/null @@ -1,165 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/mma_sparse_tensor_op.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Operator describing the tensor operation - typename Operator_ = arch::OpMultiplyAdd, - /// Number of partitions along K dimension - int PartitionsK = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false -> -struct DefaultSparseMmaTensorOp; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs and output types are float - uses TF32 internally -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of target matrix multiply instruction (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultSparseMmaTensorOp< - WarpShape_, - InstructionShape_, - float, LayoutA, - float, LayoutB, - float, LayoutC, - arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { - - // Uses TF32 internally - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::SparseMma< - InstructionShape_, - 32, - tfloat32_t, cutlass::layout::RowMajor, - tfloat32_t, cutlass::layout::ColumnMajor, - float, cutlass::layout::RowMajor, - arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::SparseMmaTensorOp< - WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for m-by-n-by-kgroup -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Operator describing the tensor operation - typename Operator_, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultSparseMmaTensorOp { - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::SparseMma, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::SparseMmaTensorOp< - WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h deleted file mode 100644 index 44d7fe1155bdd3e60bdc935e9ba48afa7cbf8f84..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h +++ /dev/null @@ -1,123 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Operator describing the tensor operation - typename Operator_ = arch::OpMultiplyAdd, - /// Number of partitions along K dimension - int PartitionsK = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false> -struct DefaultMmaTensorOp; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for m-by-n-by-kgroup -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Operator describing the tensor operation - typename Operator_, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp { - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOp< - WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/gemm/warp/default_mma_tensor_op_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h deleted file mode 100644 index 8c9abb8236230edd5787a4422907cef90a525579..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h +++ /dev/null @@ -1,375 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs and output types are float - uses BF16 internally -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp< - WarpShape_, - GemmShape<16, 8, 8>, - float, LayoutA, - float, LayoutB, - float, LayoutC, - arch::OpMultiplyAddFastBF16, - PartitionsK, AccumulatorsInRowMajor> { - - // Uses BF16 internally - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - GemmShape<16, 8, 8>, - 32, - bfloat16_t, cutlass::layout::RowMajor, - bfloat16_t, cutlass::layout::ColumnMajor, - float, cutlass::layout::RowMajor, - arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOp< - WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs and output types are float - uses F16 internally -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp< - WarpShape_, - GemmShape<16, 8, 8>, - float, LayoutA, - float, LayoutB, - float, LayoutC, - arch::OpMultiplyAddFastF16, - PartitionsK, AccumulatorsInRowMajor> { - - // Uses F16 internally - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - GemmShape<16, 8, 8>, - 32, - half_t, cutlass::layout::RowMajor, - half_t, cutlass::layout::ColumnMajor, - float, cutlass::layout::RowMajor, - arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOp< - WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs and output types are float - uses TF32 internally -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of target matrix multiply instruction (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp< - WarpShape_, - InstructionShape_, - float, LayoutA, - float, LayoutB, - float, LayoutC, - arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { - - // Uses TF32 internally - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - tfloat32_t, cutlass::layout::RowMajor, - tfloat32_t, cutlass::layout::ColumnMajor, - float, cutlass::layout::RowMajor, - arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOp< - WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs and output types are float - uses TF32 for Fast Accurate FP32 -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of target matrix multiply instruction (concept: GemmShape) - typename InstructionShape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp< - WarpShape_, - InstructionShape_, - float, LayoutA, - float, LayoutB, - float, LayoutC, - arch::OpMultiplyAddFastF32, PartitionsK, AccumulatorsInRowMajor> { - - // Uses TF32 internally - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape_, - 32, - cutlass::tfloat32_t, cutlass::layout::RowMajor, - cutlass::tfloat32_t, cutlass::layout::ColumnMajor, - float, cutlass::layout::RowMajor, - arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpFastF32< - WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs are mixed types - uses wider datatype internally. -/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32) -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Element type of A matrix - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Element type of B matrix - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp< - WarpShape_, - GemmShape<16, 8, 16>, // InstructionShape - ElementA, // Element type of A matrix in Global Memory - LayoutA, // Layout of A matrix in Global Memory - ElementB, // Element type of B matrix in Global Memory - LayoutB, // Layout of B matrix in Global Memory - ElementC, // Element type of C matrix in Global Memory - LayoutC, // Layout of C matrix in Global Memory - arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype - PartitionsK, AccumulatorsInRowMajor> { - - - // Check if the ElementA and ElementB are of different data types - static_assert(!platform::is_same::value, - "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type"); - - // Data type used for internal computation - use the wider of the two data types for mma.sync operands - using ElementOperand = typename platform::conditional<(sizeof_bits::value > sizeof_bits::value), - ElementA, ElementB>::type; - - // Operand datatypes in the internal MMA instruction - use the wider of the two data types - using ElementAMma = ElementOperand; - using ElementBMma = ElementOperand; - using MmaElementC = ElementC; - - // Uses - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - GemmShape<16, 8, 16>, - 32, - ElementAMma, cutlass::layout::RowMajor, - ElementBMma, cutlass::layout::ColumnMajor, - MmaElementC, cutlass::layout::RowMajor, - arch::OpMultiplyAdd - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaMixedInputTensorOp< - WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization - inputs are mixed types - uses wider datatype internally. -/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32) -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Element type of A matrix - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Element type of B matrix - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp< - WarpShape_, - GemmShape<16, 8, 32>, // InstructionShape - ElementA, // Element type of A matrix in Global Memory - LayoutA, // Layout of A matrix in Global Memory - ElementB, // Element type of B matrix in Global Memory - LayoutB, // Layout of B matrix in Global Memory - ElementC, // Element type of C matrix in Global Memory - LayoutC, // Layout of C matrix in Global Memory - arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype - PartitionsK, AccumulatorsInRowMajor> { - - - // Check if the ElementA and ElementB are of different data types - static_assert(!platform::is_same::value, - "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type"); - - // Data type used for internal computation - use the wider of the two data types for mma.sync operands - using ElementOperand = typename platform::conditional<(sizeof_bits::value > sizeof_bits::value), - ElementA, ElementB>::type; - - // Operand datatypes in the internal MMA instruction - use the wider of the two data types - using MmaElementA = ElementOperand; - using MmaElementB = ElementOperand; - using MmaElementC = ElementC; - - // Uses - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - GemmShape<16, 8, 32>, - 32, - MmaElementA, cutlass::layout::RowMajor, - MmaElementB, cutlass::layout::ColumnMajor, - MmaElementC, cutlass::layout::RowMajor, - arch::OpMultiplyAddSaturate - >, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaMixedInputTensorOp< - WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h deleted file mode 100644 index 7bd8c0fde5f0d3360c9468484ee61721fc9f30e0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h +++ /dev/null @@ -1,92 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/mma_with_reduction_tensor_op.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Operator describing the tensor operation - typename Operator_, - /// Reduce operand A or B along K dimension - bool ReduceKForA_, - /// Number of partitions along K dimension - int PartitionsK = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false> -struct DefaultMmaWithReductionTensorOp { - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaWithReductionTensorOp< - WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - Policy, ReduceKForA_, PartitionsK, AccumulatorsInRowMajor>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h deleted file mode 100644 index 6a90a780520e888733f74a3d84e447470684c094..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h +++ /dev/null @@ -1,130 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/arch/wmma.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - ///< Size of the Gemm problem (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Operator describing the tensor operation - typename Operator_ = arch::OpMultiplyAdd, - /// Number of partitions along K dimension - int PartitionsK = 1 -> -struct DefaultMmaTensorOpWmma; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for m-by-n-by-kgroup -template < - ///< Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Operator describing the tensor operation - typename Operator_, - /// Number of partitions along K dimension - int PartitionsK> -struct DefaultMmaTensorOpWmma { - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Wmma< - InstructionShape_, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Operator_>, - cutlass::MatrixShape<1, 1> >; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpWmma< - WarpShape_, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - Policy, - PartitionsK>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -#endif diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h deleted file mode 100644 index f032f26fcac99de781d67d4012e813e920803948..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h +++ /dev/null @@ -1,139 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level per channel scale+bias+relu before - matrix multiply-accumulate operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LayernormScaleBiasTransform { - - using T = typename FragmentActivations::Element; - - static int const NumActivations = FragmentActivations::kElements; - static int const NumVarMean = FragmentVarMean::kElements; - static int const NumGammaBeta = FragmentGammaBeta::kElements; - static int const MmaElements = 2; - // One element has one scale and one bias - static int const MmaScaleBiasPair = 2; - // 16816 has 2 columns and 2 rows - static int const MmaCols = 2; - static int const MmaRows = 2; - - using MmaOperand = Array; - using VarMeanOperand = Array<__half2, MmaScaleBiasPair>; - using GammaBetaOperand = Array; - - CUTLASS_DEVICE - void transform(MmaOperand &activations, - VarMeanOperand const &var_mean, - GammaBetaOperand const &gamma_beta) { - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t *ptr_activations = reinterpret_cast(&activations); - uint32_t const *ptr_var_mean = reinterpret_cast(&var_mean); - uint32_t const *ptr_gamma_beta = reinterpret_cast(&gamma_beta); - - // Apply per channel scale+bias+relu if the data is not a special NaN - // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. - - // We assumes the pair of FP16 are either both inbound or both out-of-bound. - // It requires C to be an even number. - asm volatile( - "{\n\t" - " fma.rn.f16x2 %0, %1, %2, %3;\n" - " fma.rn.f16x2 %0, %4, %0, %5;\n" - "}\n" - : "=r"(ptr_activations[0]) - : "r"(ptr_var_mean[0]), "r"(ptr_activations[0]), - "r"(ptr_var_mean[1]), - "r"(ptr_gamma_beta[0]), "r"(ptr_gamma_beta[1])); -#else - assert(0); -#endif - } - - CUTLASS_DEVICE - void operator()(FragmentActivations &activations, - FragmentVarMean const &var_mean, - FragmentGammaBeta const &gamma_beta) { - MmaOperand *ptr_activations = reinterpret_cast(&activations); - VarMeanOperand const *ptr_var_mean = - reinterpret_cast(&var_mean); - GammaBetaOperand const *ptr_gamma_beta = - reinterpret_cast(&gamma_beta); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < (NumActivations / MmaElements); ++i) { - transform(ptr_activations[i], - ptr_var_mean[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows], - ptr_gamma_beta[(i / MmaScaleBiasPair) % MmaCols]); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma.h deleted file mode 100644 index cd67743301140d50d38b27926b56d654168f5fdd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma.h +++ /dev/null @@ -1,60 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for warp-level multiply-add operations -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Query the number of threads per warp -template -struct WarpSize { - static int const value = 32; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h deleted file mode 100644 index e4b7cf0384627299e2ad4e916bc023cf7384e242..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ /dev/null @@ -1,1168 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/functional.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" -#include "cutlass/arch/mma_sm90.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - /// Data type of real & imag members of complex numbers in the SourceFragment - typename RealElement, - /// Destination fragment required by the mma operation - typename DestinationFragment, - /// Source fragment holding complex elements - typename SourceFragment, - /// Number of mma operations performed - typename MmaIterations, - /// Shape of operand elements - typename MmaOperandShape, - /// Complex transform on A operand - ComplexTransform Transform_, - /// Operand A or Operand B - Operand Operand_, - /// Floating-point rounding style - FloatRoundStyle Round_> -struct UnpackComplexConvertAndPackForMma; - -// Partial specialization for OperandA and Congruous smem layout -template < - typename RealElement, - typename DestinationFragment, - typename SourceFragment, - typename MmaIterations, - typename MmaOperandShape, - ComplexTransform Transform_, - FloatRoundStyle Round_> -struct UnpackComplexConvertAndPackForMma < - RealElement, - DestinationFragment, - SourceFragment, - MmaIterations, - MmaOperandShape, - Transform_, - Operand::kA, - Round_> { - - // - // Type definitions - // - static Operand const kOperand = Operand::kA; - static ComplexTransform const kTransform = Transform_; - static FloatRoundStyle const kRound = Round_; - - // Data type of elements in the destination fragment - using MmaElement = typename DestinationFragment::Element; - - // Numeric convertor MmaElement <= RealElement - using Converter = NumericConverter; - - // Operand layout parameters - using SourceFragmentLayout = layout::ColumnMajor; - static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; - - /// Ctor - CUTLASS_DEVICE - UnpackComplexConvertAndPackForMma() {} - - CUTLASS_DEVICE - void operator()(DestinationFragment *dest, SourceFragment const &source) { - - Converter convert_op; - SourceFragmentLayout layout(kLdm); - - CUTLASS_PRAGMA_UNROLL - for(int i=0; i and apply rounding on real and imag parts - MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); - MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); - - // Unpack rounded complex and pack into DestinationFragment for mma operation - dest[i][pos] = a; - dest[i+MmaIterations::kRow][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); - - } - } - } - } -}; - -// Partial specialization for OperandB and Congruous smem layout -template < - typename RealElement, - typename DestinationFragment, - typename SourceFragment, - typename MmaIterations, - typename MmaOperandShape, - ComplexTransform Transform_, - FloatRoundStyle Round_> -struct UnpackComplexConvertAndPackForMma < - RealElement, - DestinationFragment, - SourceFragment, - MmaIterations, - MmaOperandShape, - Transform_, - Operand::kB, - Round_> { - - // - // Type definitions - // - static Operand const kOperand = Operand::kB; - static ComplexTransform const kTransform = Transform_; - static FloatRoundStyle const kRound = Round_; - - // Data type of elements in the destination fragment - using MmaElement = typename DestinationFragment::Element; - - // Numeric convertor MmaElement <= RealElement - using Converter = NumericConverter; - - // Operand layout parameters - using SourceFragmentLayout = layout::RowMajor; - static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; - - /// Ctor - CUTLASS_DEVICE - UnpackComplexConvertAndPackForMma() {} - - CUTLASS_HOST_DEVICE - void operator()(DestinationFragment *dest, SourceFragment const &source) { - - Converter convert_op; - SourceFragmentLayout layout(kLdm); - - CUTLASS_PRAGMA_UNROLL - for(int i=0; i apply rounding on real and imag parts - MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); - MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); - - // Unpack rounded complex and pack into DestinationFragment for mma operation - dest[i][pos] = a; - dest[i+MmaIterations::kColumn][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); - } - } - } - } -}; -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transform on B operand - ComplexTransform TransformB = ComplexTransform::kNone, - /// Do source operands need more than one elements - bool GeneralizedOperatorElements = false, - /// Used for partial specialization - typename Enable = bool -> -class MmaComplexTensorOp; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB -> -class MmaComplexTensorOp< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Indicates math operator - using MathOperator = arch::OpMultiplyAddComplex; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = FragmentA; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = FragmentB; - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - Shape::kM / ArchMmaOperator::Shape::kM, - Shape::kN / ArchMmaOperator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued - /// parts are stored consecutively followed by all imaginary parts. This matches the structure - /// of Tensor Cores which are always real-valued matrix multiplies. - using FragmentC = typename IteratorC::Fragment; - - static_assert( - FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, - "Unexpected planar complex fragment length."); - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaComplexTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C - ) const { - - // Alias types for underlying real-valued matrix multiply operator - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert(MmaOperandA::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the A operand." - "We can geneneralize later."); - - static_assert(MmaOperandB::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the B operand." - "We can geneneralize later."); - - D = C; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.real(), a.real(), b.real(), accum.real()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - operand_A[0] = A[m].real(); - operand_B[0] = B[n].real(); - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.imag(), a.real(), b.imag(), accum.imag()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - operand_A[0] = A[m].real(); - operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.real(), -a.imag(), b.imag(), accum.real()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - // A imaginary part is intentionally negated - operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag()); - operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A, operand_B, *accum); - } - - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag()); - operand_B[0] = B[n].real(); - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A, operand_B, *accum); - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - dst_A = A; - dst_B = B; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex+complex => complex: -// Operands data type: complex -// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -// Output data type: complex -// -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB -> -class MmaComplexTensorOp< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of members of complex multiplicand A - using RealElementA = float; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of members of complex multiplicand B - using RealElementB = float; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of members of complex accumulator matrix C - using RealElementC = float; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Underlying arch tag - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Indicates math operator - using MathOperator = typename arch::OpMultiplyAddComplex; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = - Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = - Array; - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of complex products operations performed (one complex product needs four mma instructions) - using MmaIterations = MatrixShape< - Shape::kM / ArchMmaOperator::Shape::kM, - Shape::kN / ArchMmaOperator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued - /// parts are stored consecutively followed by all imaginary parts. This matches the structure - /// of Tensor Cores which are always real-valued matrix multiplies. - using FragmentC = typename IteratorC::Fragment; - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaComplexTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - - // Alias types for underlying real-valued matrix multiply operator - using InstMmaOperandA = typename ArchMmaOperator::FragmentA; - using InstMmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, - "This implementation only supports mma.m16n8k8 math instructions."); - - static_assert(InstMmaOperandA::kElements == 4, - "This implementation only supports math instructions in which exactly four element is needed for the A operand." - "We can geneneralize later."); - - static_assert(InstMmaOperandB::kElements == 2, - "This implementation only supports math instructions in which exactly two element is needed for the B operand." - "We can geneneralize later."); - - // Instruction Operands A & B holding real part followed by imaginary part for mma operations - InstMmaOperandA const *operand_A = reinterpret_cast(&A); - InstMmaOperandB const *operand_B = reinterpret_cast(&B); - - // - // Accumulate in place - // - D = C; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.real(), a.real(), b.real(), accum.real()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A[m], operand_B[n], *accum); - } - - // mma(accum.imag(), a.real(), b.imag(), accum.imag()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); - } - - // mma(accum.real(), a.imag(), -b.imag(), accum.real()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // negate OperandB to accumulate -(a.imag()*b.imag()) - // negating OperandB emits less instructions than negating OperandA as OperandB has less elements - negate negate_op; - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); - } - - // mma(accum.imag(), a.imag(), b.real(), accum.imag()) - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - // Alias types for underlying real-valued matrix multiply operator - using InstMmaOperandA = typename ArchMmaOperator::FragmentA; - using InstMmaOperandB = typename ArchMmaOperator::FragmentB; - - // - // Define conversions from source type to instruction operands' type - // - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 - FloatRoundStyle const kRoundA = FloatRoundStyle::round_to_nearest; - FloatRoundStyle const kRoundB = FloatRoundStyle::round_to_nearest; - #else - FloatRoundStyle const kRoundA = FloatRoundStyle::round_half_ulp_trunc_dntz; - FloatRoundStyle const kRoundB = FloatRoundStyle::round_half_ulp_trunc_dntz; - #endif - - detail::UnpackComplexConvertAndPackForMma < - RealElementA, - InstMmaOperandA, - FragmentA, - MmaIterations, - MatrixShape<2, 2>, - kTransformA, - Operand::kA, - kRoundA> convert_A; - - detail::UnpackComplexConvertAndPackForMma < - RealElementB, - InstMmaOperandB, - FragmentB, - MmaIterations, - MatrixShape<2, 1>, - kTransformB, - Operand::kB, - kRoundB> convert_B; - - // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element - convert_A(reinterpret_cast(&dst_A), A); - convert_B(reinterpret_cast(&dst_B), B); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for complex*complex+complex => complex: -// Operands data type: complex -// Math instruction: mma.sync.aligned.m16n8k4.f64.f64.f64.f64 -// Output data type: complex -// -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB -> -class MmaComplexTensorOp< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB, - true> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of members of complex multiplicand A - using RealElementA = double; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of members of complex multiplicand B - using RealElementB = double; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of members of complex accumulator matrix C - using RealElementC = double; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Underlying arch tag - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Indicates math operator - using MathOperator = typename arch::OpMultiplyAddComplex; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = FragmentA; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = FragmentB; - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - Shape::kM / ArchMmaOperator::Shape::kM, - Shape::kN / ArchMmaOperator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued - /// parts are stored consecutively followed by all imaginary parts. This matches the structure - /// of Tensor Cores which are always real-valued matrix multiplies. - using FragmentC = typename IteratorC::Fragment; - - static_assert( - FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, - "Unexpected planar complex fragment length."); - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaComplexTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C - ) const { - - // Alias types for underlying real-valued matrix multiply operator - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - D = C; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.real(), a.real(), b.real(), accum.real()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.imag(), a.real(), b.imag(), accum.imag()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? - -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.real(), -a.imag(), b.imag(), accum.real()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - // A imaginary part is intentionally negated - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? - A[m*MmaOperandA::kElements + mk].imag() : -A[m*MmaOperandA::kElements + mk].imag()); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? - -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.imag(), a.imag(), b.real(), accum.imag()) - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? - -A[m*MmaOperandA::kElements + mk].imag() : A[m*MmaOperandA::kElements + mk].imag()); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A, operand_B, *accum); - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - dst_A = A; - dst_B = B; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h deleted file mode 100644 index fd90ab8c4252f95fff90c21bfeaf6fb45c4b110b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h +++ /dev/null @@ -1,663 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/functional.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -namespace detail { - -template < - /// Data type of real & imag members of complex numbers in the SourceFragment - typename RealElement, - /// Destination fragment required by the mma operation - typename DestinationFragment, - /// Source fragment holding complex elements - typename SourceFragment, - /// Number of mma operations performed - typename MmaIterations, - /// Shape of operand elements - typename MmaOperandShape, - /// Complex transform on A operand - ComplexTransform Transform_, - /// Operand A or Operand B - Operand Operand_, - /// Floating-point rounding style for big part - FloatRoundStyle RoundBig_, - /// Floating-point rounding style for small part - FloatRoundStyle RoundSmall_> -struct UnpackComplexConvertAndPackForMmaFastF32; - -// Partial specialization for OperandA and Congruous smem layout -template < - typename RealElement, - typename DestinationFragment, - typename SourceFragment, - typename MmaIterations, - typename MmaOperandShape, - ComplexTransform Transform_, - FloatRoundStyle RoundBig_, - FloatRoundStyle RoundSmall_> -struct UnpackComplexConvertAndPackForMmaFastF32 < - RealElement, - DestinationFragment, - SourceFragment, - MmaIterations, - MmaOperandShape, - Transform_, - Operand::kA, - RoundBig_, - RoundSmall_> { - - // - // Type definitions - // - static Operand const kOperand = Operand::kA; - static ComplexTransform const kTransform = Transform_; - static FloatRoundStyle const kRoundBig = RoundBig_; - static FloatRoundStyle const kRoundSmall = RoundSmall_; - - // Data type of elements in the destination fragment - using MmaElement = typename DestinationFragment::Element; - - // Numeric convertor MmaElementBig, MmaElementSmall <= RealElement - using Converter = NumericConverterFastF32; - - // Operand layout parameters - using SourceFragmentLayout = layout::ColumnMajor; - static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; - - // BigSmall Fragment holding two TF32 elements (big, small) for every float - using BigSmallFragment = Array; - - /// Index in fargments for the big and small part - static int const kBigIndex = 0; - static int const kSmallIndex = 1; - - /// Ctor - CUTLASS_DEVICE - UnpackComplexConvertAndPackForMmaFastF32() {} - - CUTLASS_DEVICE - void operator()(DestinationFragment *dest, SourceFragment const &source) { - - Converter convert_op; - SourceFragmentLayout layout(kLdm); - - DestinationFragment *dest_big_ = reinterpret_cast(dest); - DestinationFragment *dest_small_ = reinterpret_cast(&dest[MmaIterations::kRow * 2]); - - CUTLASS_PRAGMA_UNROLL - for(int i=0; i and apply rounding on real and imag parts - BigSmallFragment a = convert_op(source[layout(MatrixCoord{row,col})].real()); - BigSmallFragment b = convert_op(source[layout(MatrixCoord{row,col})].imag()); - - // Unpack rounded complex and pack into DestinationFragment for mma operation - dest_big_[i][pos] = a[kBigIndex]; - dest_big_[i+MmaIterations::kRow][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kBigIndex] : b[kBigIndex]); - - // Unpack rounded complex and pack into DestinationFragment for mma operation - dest_small_[i][pos] = a[kSmallIndex]; - dest_small_[i+MmaIterations::kRow][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kSmallIndex] : b[kSmallIndex]); - - // Next position - pos++; - } - } - } - } -}; - -// Partial specialization for OperandB and Congruous smem layout -template < - typename RealElement, - typename DestinationFragment, - typename SourceFragment, - typename MmaIterations, - typename MmaOperandShape, - ComplexTransform Transform_, - FloatRoundStyle RoundBig_, - FloatRoundStyle RoundSmall_> -struct UnpackComplexConvertAndPackForMmaFastF32 < - RealElement, - DestinationFragment, - SourceFragment, - MmaIterations, - MmaOperandShape, - Transform_, - Operand::kB, - RoundBig_, - RoundSmall_> { - - // - // Type definitions - // - static Operand const kOperand = Operand::kB; - static ComplexTransform const kTransform = Transform_; - static FloatRoundStyle const kRoundBig = RoundBig_; - static FloatRoundStyle const kRoundSmall = RoundSmall_; - - // Data type of elements in the destination fragment - using MmaElement = typename DestinationFragment::Element; - - // Numeric convertor MmaElementBig, MmaElementSmall <= RealElement - using Converter = NumericConverterFastF32; - - // Operand layout parameters - using SourceFragmentLayout = layout::RowMajor; - static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; - - // BigSmall Fragment holding two TF32 elements (big, small) for every float - using BigSmallFragment = Array; - - /// Index in fargments for the big and small part - static int const kBigIndex = 0; - static int const kSmallIndex = 1; - - /// Ctor - CUTLASS_DEVICE - UnpackComplexConvertAndPackForMmaFastF32() {} - - CUTLASS_HOST_DEVICE - void operator()(DestinationFragment *dest, SourceFragment const &source) { - - Converter convert_op; - SourceFragmentLayout layout(kLdm); - - DestinationFragment *dest_big_ = reinterpret_cast(dest); - DestinationFragment *dest_small_ = reinterpret_cast(&dest[MmaIterations::kColumn * 2]); - - CUTLASS_PRAGMA_UNROLL - for(int i=0; i apply rounding on real and imag parts - BigSmallFragment a = convert_op(source[layout(MatrixCoord{row,col})].real()); - BigSmallFragment b = convert_op(source[layout(MatrixCoord{row,col})].imag()); - - // Unpack rounded complex and pack into DestinationFragment for mma operation - dest_big_[i][pos] = a[kBigIndex]; - dest_big_[i+MmaIterations::kColumn][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kBigIndex] : b[kBigIndex]); - - // Unpack rounded complex and pack into DestinationFragment for mma operation - dest_small_[i][pos] = a[kSmallIndex]; - dest_small_[i+MmaIterations::kColumn][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kSmallIndex] : b[kSmallIndex]); - - // next position - pos++; - } - } - } - } -}; -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transform on B operand - ComplexTransform TransformB = ComplexTransform::kNone, - /// Used for partial specialization - typename Enable = bool -> -class MmaComplexTensorOpFastF32; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex+complex => complex: -// Operands data type: complex -// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -// Output data type: complex -// -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB, - /// Used for partial specialization - typename Enable -> -class MmaComplexTensorOpFastF32< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB, - Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of members of complex multiplicand A - using RealElementA = float; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of members of complex multiplicand B - using RealElementB = float; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of members of complex accumulator matrix C - using RealElementC = float; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Underlying arch tag - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Indicates math operator - using MathOperator = arch::OpMultiplyAddComplexFastF32; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - - /// Tune F32 to TF32 big small conversion for complex operation - /// Different combination of big small conversin can cause different tradeoff - /// between speed and accuracy. Generally, use round_half_ulp_truncate can - /// improve the performance but hur the accuracy. - using ComplexFastF32 = FastF32 < - FloatRoundStyle::round_toward_zero, // kRoundBigA - FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallA - FloatRoundStyle::round_toward_zero, // kRoundBigB - FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallB - TensorFloat32Op::k3xTF32 // Number of TF32 operations - >; - - /// Index in fargments for the big and small part - static int const kBigIndex = 0; - static int const kSmallIndex = 1; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - // (4 times the original FragmentA::kElements) - // (real_big), (imag_big), (real_small), (imag_small) - using TransformedFragmentA = Array; - - // Fragment bisecting big and small sections - // (real_big, imag_big), (real_small, imag_small) - using AccessTypeFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - // (4 times the original FragmentB::kElements) - // (real_big), (imag_big), (real_small), (imag_small) - using TransformedFragmentB = Array; - - // Fragment bisecting big and small sections - // (real_big, imag_big), (real_small, imag_small) - using AccessTypeFragmentB = Array; - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of complex products operations performed (one complex product needs four mma instructions) - using MmaIterations = MatrixShape< - Shape::kM / ArchMmaOperator::Shape::kM, - Shape::kN / ArchMmaOperator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued - /// parts are stored consecutively followed by all imaginary parts. This matches the structure - /// of Tensor Cores which are always real-valued matrix multiplies. - using FragmentC = typename IteratorC::Fragment; - - // - // Alias types for underlying real-valued matrix multiply operator - // - using InstMmaOperandA = typename ArchMmaOperator::FragmentA; - using InstMmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, - "This implementation only supports mma.m16n8k8 math instructions."); - - static_assert(InstMmaOperandA::kElements == 4, - "This implementation only supports math instructions in which exactly four element is needed for the A operand." - "We can geneneralize later."); - - static_assert(InstMmaOperandB::kElements == 2, - "This implementation only supports math instructions in which exactly two element is needed for the B operand." - "We can geneneralize later."); - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaComplexTensorOpFastF32() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - - AccessTypeFragmentA const *complex_A = reinterpret_cast(&A); - AccessTypeFragmentB const *complex_B = reinterpret_cast(&B); - - // - // Accumulate in place - // - D = C; - - - complex_mma_operator(D, complex_A[kSmallIndex], complex_B[kBigIndex], D); - - complex_mma_operator(D, complex_A[kBigIndex], complex_B[kSmallIndex], D); - - complex_mma_operator(D, complex_A[kBigIndex], complex_B[kBigIndex], D); - - if (ComplexFastF32::kPrecision == TensorFloat32Op::k4xTF32) - complex_mma_operator(D, complex_A[kSmallIndex], complex_B[kSmallIndex], D); - } - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void complex_mma_operator( - FragmentC &D, - AccessTypeFragmentA const &complex_A, - AccessTypeFragmentB const &complex_B, - FragmentC const &C - ) const { - - // Instruction Operands A & B holding real part followed by imaginary part for mma operations - InstMmaOperandA const *operand_A = reinterpret_cast(&complex_A); - InstMmaOperandB const *operand_B = reinterpret_cast(&complex_B); - - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.real(), a.real(), b.real(), accum.real()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A[m], operand_B[n], *accum); - } - - // mma(accum.imag(), a.real(), b.imag(), accum.imag()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); - } - - // mma(accum.real(), a.imag(), -b.imag(), accum.real()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // negate OperandB to accumulate -(a.imag()*b.imag()) - // negating OperandB emits less instructions than negating OperandA as OperandB has less elements - negate negate_op; - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); - } - - // mma(accum.imag(), a.imag(), b.real(), accum.imag()) - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - - detail::UnpackComplexConvertAndPackForMmaFastF32 < - RealElementA, - InstMmaOperandA, - FragmentA, - MmaIterations, - MatrixShape<2, 2>, - kTransformA, - Operand::kA, - ComplexFastF32::kRoundBigA, - ComplexFastF32::kRoundSmallA> convert_A; - - detail::UnpackComplexConvertAndPackForMmaFastF32 < - RealElementB, - InstMmaOperandB, - FragmentB, - MmaIterations, - MatrixShape<2, 1>, - kTransformB, - Operand::kB, - ComplexFastF32::kRoundBigB, - ComplexFastF32::kRoundSmallB> convert_B; - - // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element - convert_A(reinterpret_cast(&dst_A), A); - convert_B(reinterpret_cast(&dst_B), B); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h deleted file mode 100644 index e14450d363f18bbd63ef129b398a97545f29dc95..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h +++ /dev/null @@ -1,2485 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for loading 128b vectors of 128b elements. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCongruous128b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - static_assert(!(Shape::kContiguous % 8) && !(Shape::kStrided % 4), "Divisibility."); - - static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous128b; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Load two elements per access - static int const kElementsPerAccess = 1; - - /// Policy defining internal details of tile iterator - struct Policy { - - /// Shape of one access - using Delta = layout::PitchLinearShape<8, 4>; - - /// Number of iterations to load - using Iterations = layout::PitchLinearShape< - Shape::kContiguous / Delta::kContiguous, - InstructionShape::kStrided / Delta::kStrided - >; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { - - int quad_pair = lane_id / 8; - int quad = lane_id / 4; - int lane = lane_id % 4; - - int row = (quad & 1) * 4 + (lane ^ quad_pair); - - byte_offset_ = (row + quad_pair * stride_) * sizeof(AccessType); - - pointer_= reinterpret_cast(ref.data()); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - pointer_ += offset; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int offset = - (tile_offset.contiguous() * Shape::kContiguous) + - (tile_offset.strided() * InstructionShape::kStrided * stride_); - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - pointer_ += stride_ * InstructionShape::kStrided; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::Iterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::Iterations::kContiguous; - - AccessType const *source_ptr = pointer_ + - Policy::Delta::kContiguous * c + - Policy::Delta::kStrided * s * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - AccessType const *source = reinterpret_cast(source_byte_ptr); - - fetch_ptr[access_idx] = *source; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - - } -}; - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCongruous128b, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.strided(), tile_offset.contiguous()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCongruous128b, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// -/// Partial specialization for complex -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of underlying field of reals. - typename RealElement, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpAccumulatorTileIterator< - Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = complex; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape; - }; - -private: - - // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire - // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements - // of that row. The accumulators within one row are assumed to be consecutive. - static int const kElementsPerAccess = InstructionShape::kN / 4; - static int const kRowsPerTile = 8; - static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators - /// are stored in a planar complex arrangement with the real parts as entirely contiguous - /// followed by the imaginary parts. - using Fragment = Array; - - static int const kRealIndex = 0; - static int const kImaginaryIndex = Shape::kCount / kThreads; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - - Element z = offset_ref.at({accum_m, accum_n}); - - frag[mma_accum_start + row * kElementsPerAccess + col + kRealIndex] = z.real(); - frag[mma_accum_start + row * kElementsPerAccess + col + kImaginaryIndex] = z.imag(); - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - int idx = mma_accum_start + row * kElementsPerAccess + col; - - Element z(frag[kRealIndex + idx], frag[kImaginaryIndex + idx]); - - offset_ref.at({accum_m, accum_n}) = z; - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for loading 128b vectors of 128b elements. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCrosswise128x4, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 8), "Divisibility."); - - static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCrosswise128x4; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Load two elements per access - static int const kElementsPerAccess = 1; - - /// Policy defining internal details of tile iterator - struct Policy { - - /// Shape of one access - using Delta = layout::PitchLinearShape<4, 8>; - - /// Number of iterations to load - using Iterations = layout::PitchLinearShape< - InstructionShape::kContiguous / Delta::kContiguous, - Shape::kStrided / Delta::kStrided - >; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { - - int quad = lane_id / 4; - int liq = lane_id % 4; - - int c = liq + (quad & 1) * 4; - int s = (quad / 2); - - byte_offset_ = (c + s * stride_) * sizeof(AccessType); - - pointer_= reinterpret_cast(ref.data()); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - pointer_ += offset; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - // Compute the offset in units of elements. Note, the external coordinate system is - // approximately transposed with respect to the tiled internal structure - int offset = - (tile_offset.contiguous() * InstructionShape::kContiguous) * stride_ + - (tile_offset.strided() * Shape::kStrided); - - add_pointer_offset(offset); - - byte_offset_ ^= (tile_offset.contiguous() & 1) * 4 * sizeof(AccessType); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - pointer_ += stride_ * InstructionShape::kContiguous; - - byte_offset_ ^= 4 * sizeof(AccessType); - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::Iterations::kStrided; ++s) { - - int access_idx = s + c * Policy::Iterations::kStrided; - - AccessType const *source_ptr = pointer_ + - Policy::Delta::kContiguous * c * stride_ + - Policy::Delta::kStrided * s; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - AccessType const *source = reinterpret_cast(source_byte_ptr); - - fetch_ptr[access_idx] = *source; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * InstructionShape::kContiguous * stride_ + - tile_offset.strided() * Shape::kStrided; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCrosswise128x4, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.strided(), tile_offset.contiguous()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCrosswise128x4, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Congruous shared memory layout -// Warp-level iterators for complex*complex + complex => complex -// The underlying iterators are similar to that for MMA f64*f64 + f64 = f64 -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for loading 128b vectors of 64b elements. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, cutlass::complex, - cutlass::layout::TensorOpMultiplicandCongruous64b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 8), "Divisibility."); - - /// Element type - using Element = cutlass::complex; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Load two elements per access - static int const kElementsPerAccess = 2; - - /// Policy defining internal details of tile iterator - struct Policy { - - /// Shape of one access - using Delta = layout::PitchLinearShape<8, 4>; - - /// Number of iterations to load - using Iterations = layout::PitchLinearShape< - Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, - InstructionShape::kStrided / Delta::kStrided - >; - - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - - /// Internal counter used to jump to next K partition - int k_group_idx_; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), - k_group_idx_(0) { - - int access_strided = lane_id / Policy::Delta::kContiguous; - int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; - - pointer_= reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int offset = - (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + - tile_offset.contiguous() * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - add_tile_offset({0, 1}); - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - add_tile_offset({0, -1}); - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::Iterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::Iterations::kContiguous; - - AccessType const *source_ptr = pointer_ + - Policy::Delta::kContiguous * c + - Policy::Delta::kStrided * s * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - AccessType const *source = reinterpret_cast(source_byte_ptr); - - fetch_ptr[access_idx] = *source; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Crosswise shared memory layout -// Warp-level iterators for complex*complex + complex => complex -// The underlying iterators are similar to that for f64*f64 + f64 = f64 -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for loading 128b vectors of 64b elements. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, complex, - cutlass::layout::TensorOpMultiplicand64bCrosswise, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); - - static_assert(sizeof_bits>::value == 64, "This is specialized for 64b accesses."); - - /// Element type - using Element = complex; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Load two elements per access - static int const kElementsPerAccess = 2; - - /// Policy defining internal details of tile iterator - struct Policy { - - /// Shape of one access - using Delta = layout::PitchLinearShape<4, 16>; - - /// Number of iterations to load - using Iterations = layout::PitchLinearShape< - InstructionShape::kContiguous / Delta::kContiguous, - Shape::kStrided / Delta::kStrided - >; - - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - /// Internal counter for tracking K-group - Index k_group_idx_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), - k_group_idx_(0) { - - int access_strided = lane_id / 8; - int access_contiguous = (lane_id % 8); - - byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); - - pointer_= reinterpret_cast(ref.data()); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - pointer_ += offset / kElementsPerAccess; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * - stride_ * kElementsPerAccess + - tile_offset.strided() * Shape::kStrided; - - add_pointer_offset(offset); - - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { - - add_tile_offset(tile_offset); - - if (k_group_idx_ & 1) - byte_offset_ ^= 0x40; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - pointer_ += stride_ * InstructionShape::kContiguous; - - // xor ptr - byte_offset_ ^= 0x40; - - ++k_group_idx_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::Iterations::kStrided; ++s) { - - int access_idx = c * Policy::Iterations::kStrided + s; - - AccessType const *source_ptr = pointer_ + - Policy::Delta::kContiguous * c * stride_ + - Policy::Delta::kStrided * s / kElementsPerAccess; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - AccessType const *source = reinterpret_cast(source_byte_ptr); - - fetch_ptr[access_idx] = *source; - } - } - - Element *exchange_ptr = reinterpret_cast(&frag); - - // exchange on 64b granularity only for fragments held in k=8/2 to k=8 - CUTLASS_PRAGMA_UNROLL - for (int i = Fragment::kElements/2; i < Fragment::kElements; i += 2) { - Element tmp = exchange_ptr[i]; - exchange_ptr[i] = exchange_ptr[i + 1]; - exchange_ptr[i + 1] = tmp; - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = tile_offset.contiguous() * - InstructionShape::kContiguous / - Layout::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - k_group_idx_ = k_group; - } -}; - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h deleted file mode 100644 index 6728ac2010bc84e7a4edfcae956905e2432e56f3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h +++ /dev/null @@ -1,642 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transform on B operand - ComplexTransform TransformB = ComplexTransform::kNone, - /// Do source operands need more than one elements - bool GeneralizedOperatorElements = false, - /// Used for partial specialization - typename Enable = bool -> -class MmaGaussianComplexTensorOp; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB -> -class MmaGaussianComplexTensorOp< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Underlying arch tag - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Indicates math operator - using MathOperator = arch::OpMultiplyAddGaussianComplex; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = FragmentA; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = FragmentB; - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - Shape::kM / ArchMmaOperator::Shape::kM, - Shape::kN / ArchMmaOperator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is - /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively - /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. - using FragmentC = typename IteratorC::Fragment; - - static_assert( - FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, - "Unexpected gaussian complex fragment length."); - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaGaussianComplexTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C - ) const { - - // Alias types for underlying real-valued matrix multiply operator - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert(MmaOperandA::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the A operand." - "We can geneneralize later."); - - static_assert(MmaOperandB::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the B operand." - "We can geneneralize later."); - - D = C; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_Asum; - MmaOperandB operand_Br; - - operand_Asum[0] = A[m].real() + ((kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag()); - operand_Br[0] = B[n].real(); - - // accumulator part1 - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_Asum, operand_Br, *accum); - } - - // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_Ar; - MmaOperandB operand_Bdiff; - - operand_Ar[0] = -A[m].real(); - operand_Bdiff[0] = B[n].real() - ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); - - // accumulator part2 - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_Ar, operand_Bdiff, *accum); - } - - // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_Ai; - MmaOperandB operand_Bsum; - - operand_Ai[0] = (kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag(); - operand_Bsum[0] = B[n].real() + ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); - - // accumulator part3 - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; - - mma(*accum, operand_Ai, operand_Bsum, *accum); - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - dst_A = A; - dst_B = B; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB -> -class MmaGaussianComplexTensorOp< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB, - true> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Underlying arch tag - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Indicates math operator - using MathOperator = arch::OpMultiplyAddGaussianComplex; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = FragmentA; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = FragmentB; - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - Shape::kM / ArchMmaOperator::Shape::kM, - Shape::kN / ArchMmaOperator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is - /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively - /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. - using FragmentC = typename IteratorC::Fragment; - - static_assert( - FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, - "Unexpected gaussian complex fragment length."); - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaGaussianComplexTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C - ) const { - - // Alias types for underlying real-valued matrix multiply operator - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - D = C; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_Asum; - MmaOperandB operand_Br; - - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_Asum[mk] = A[m*MmaOperandA::kElements + mk].real() + ((kTransformA == ComplexTransform::kConjugate) ? - -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag()); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_Br[nk] = B[n*MmaOperandB::kElements + nk].real(); - - // accumulator part1 - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_Asum, operand_Br, *accum); - } - - // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_Ar; - MmaOperandB operand_Bdiff; - - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_Ar[mk] = -A[m*MmaOperandA::kElements + mk].real(); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_Bdiff[nk] = B[n*MmaOperandB::kElements + nk].real() - ((kTransformB == ComplexTransform::kConjugate) ? - -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); - - // accumulator part2 - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_Ar, operand_Bdiff, *accum); - } - - // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_Ai; - MmaOperandB operand_Bsum; - - CUTLASS_PRAGMA_UNROLL - for (int mk = 0; mk < MmaOperandA::kElements; ++mk) - operand_Ai[mk] = (kTransformA == ComplexTransform::kConjugate) ? - -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag(); - - CUTLASS_PRAGMA_UNROLL - for (int nk = 0; nk < MmaOperandB::kElements; ++nk) - operand_Bsum[nk] = B[n*MmaOperandB::kElements + nk].real() + ((kTransformB == ComplexTransform::kConjugate) ? - -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); - - // accumulator part3 - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; - - mma(*accum, operand_Ai, operand_Bsum, *accum); - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - dst_A = A; - dst_B = B; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h deleted file mode 100644 index ec99c77f4916e2040cef9fc724c431b0c1531f23..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h +++ /dev/null @@ -1,390 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpGaussianComplexAccumulatorTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// -/// Partial specialization for complex -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of underlying field of reals. - typename RealElement, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpGaussianComplexAccumulatorTileIterator< - Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = complex; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape; - }; - -private: - - // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire - // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements - // of that row. The accumulators within one row are assumed to be consecutive. - static int const kElementsPerAccess = InstructionShape::kN / 4; - static int const kRowsPerTile = 8; - static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators - /// are stored in a gaussian complex arrangement with parts 1, 2, and 3 as entirely contiguous - /// arranged as [part1, part2, part3] - using Fragment = Array; - - static int const kPart1Index = (Shape::kCount / kThreads) * 0; - static int const kPart2Index = (Shape::kCount / kThreads) * 1; - static int const kPart3Index = (Shape::kCount / kThreads) * 2; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpGaussianComplexAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - - Element z = offset_ref.at({accum_m, accum_n}); - - frag[mma_accum_start + row * kElementsPerAccess + col + kPart1Index] = z.real() + z.imag(); - frag[mma_accum_start + row * kElementsPerAccess + col + kPart2Index] = -z.real(); - frag[mma_accum_start + row * kElementsPerAccess + col + kPart3Index] = z.imag(); - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - int idx = mma_accum_start + row * kElementsPerAccess + col; - - Element z(frag[kPart1Index + idx] - frag[kPart3Index + idx], - frag[kPart1Index + idx] + frag[kPart2Index + idx]); - - offset_ref.at({accum_m, accum_n}) = z; - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h deleted file mode 100644 index b07575050ac2999cdcbeb0d4e8a64bfb63214cff..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h +++ /dev/null @@ -1,566 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -//////////////////////////////////////////////////////////////////////////////// -// Shuffle registers for layout conversion -//////////////////////////////////////////////////////////////////////////////// -template < - /// Element type for the operand in registers for the mma.sync - typename ElementMma_, - /// Element type for the operand in shared memory for ldmatrix - typename ElementLoad_, - /// Number of mma.sync operations performed along rows or columns - int NumMmaInstructions, - /// Number of elements in warp fragment - int NumElementsInWarpFragment, - /// Number of elements in mma fragment - int NumElementsInMmaFragment, - /// Identifies A or B multiplicand - Operand Operand_, - /// - typename Enable = void > -struct FragmentShuffler { - public: - using ElementMma = ElementMma_; - using ElementLoad = ElementLoad_; - - static int const kNumMmaInstructions = NumMmaInstructions; - static int const kNumElementsInWarpFragment = NumElementsInWarpFragment; - static int const kNumElementsInMmaFragment = NumElementsInMmaFragment; - static Operand const kOperand = Operand_; - - using WarpFragment = Array; - using MmaFragment = Array; - - CUTLASS_DEVICE - WarpFragment operator()(WarpFragment const &src) { - return src; - } -}; -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8) -/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4) -/// for operand A multiplicand going through upcasting. -template < - /// Element type for the operand in registers for the mma.sync - typename ElementMma_, - /// Element type for the operand in shared memory for ldmatrix - typename ElementLoad_, - /// Number of mma.sync operations performed along rows or columns - int NumMmaInstructions, - /// Number of elements in warp fragment - int NumElementsInWarpFragment, - /// Number of elements in mma fragment - int NumElementsInMmaFragment -> -struct FragmentShuffler ::value / - sizeof_bits::value == 2)>::type> { -public: - using ElementMma = ElementMma_; - using ElementLoad = ElementLoad_; - - static int const kNumMmaInstructions = NumMmaInstructions; - static int const kNumElementsInWarpFragment = NumElementsInWarpFragment; - static int const kNumElementsInMmaFragment = NumElementsInMmaFragment; - static Operand const kOperand = Operand::kA; - - using WarpFragment = Array; - using MmaFragment = Array; - - static uint32_t const kSelectBytesEvenThread = 0x5410; - static uint32_t const kSelectBytesOddThread = 0x7632; - -private: - int delta_up_; - int delta_down_; - int odd_even_lane_id_; - uint32_t byte_selector_; - -public: - CUTLASS_DEVICE - FragmentShuffler() { - int lane_id = cutlass::arch::LaneId(); - delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1); - delta_down_ = 2 - delta_up_; - odd_even_lane_id_ = static_cast(lane_id & 1); - byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread + - (1 - odd_even_lane_id_) * kSelectBytesEvenThread; - } - - CUTLASS_DEVICE - WarpFragment operator()(WarpFragment const &src) { - - WarpFragment result; - MmaFragment const* mma_frag_src_ptr = reinterpret_cast(&src); - MmaFragment* mma_frag_dst_ptr = reinterpret_cast(&result); - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kNumMmaInstructions; n++) { - - uint32_t const* src_ptr = reinterpret_cast(&mma_frag_src_ptr[n]); - uint32_t *dst_ptr = reinterpret_cast(&mma_frag_dst_ptr[n]); - - // Shuffle data within the warp, pull from other threads within the warp - uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_); - uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_); - uint32_t tmp2 = __shfl_up_sync(0xFFFFFFFF, src_ptr[1], delta_up_); - uint32_t tmp3 = __shfl_down_sync(0xFFFFFFFF, src_ptr[1], delta_down_); - - // Reorder the data within the 32-bit word (4x8b) required for mma.sync - dst_ptr[0] = __byte_perm(tmp0, tmp2, byte_selector_); - dst_ptr[1] = __byte_perm(tmp1, tmp3, byte_selector_); - } - - return result; - } - -}; -//////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8) -/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4) -/// for operand B multiplicand going through upcasting. -template < - /// Element type for the operand in registers for the mma.sync - typename ElementMma_, - /// Element type for the operand in shared memory for ldmatrix - typename ElementLoad_, - /// Number of mma.sync operations performed along rows or columns - int NumMmaInstructions, - /// Number of elements in warp fragment - int NumElementsInWarpFragment, - /// Number of elements in mma fragment - int NumElementsInMmaFragment -> -struct FragmentShuffler ::value / - sizeof_bits::value == 2)>::type> { -public: - using ElementMma = ElementMma_; - using ElementLoad = ElementLoad_; - - static int const kNumMmaInstructions = NumMmaInstructions; - static int const kNumElementsInWarpFragment = NumElementsInWarpFragment; - static int const kNumElementsInMmaFragment = NumElementsInMmaFragment; - static Operand const kOperand = Operand::kB; - - using WarpFragment = Array; - using MmaFragment = Array; - - static uint32_t const kSelectBytesEvenThread = 0x5410; - static uint32_t const kSelectBytesOddThread = 0x7632; - -private: - int delta_up_; - int delta_down_; - int odd_even_lane_id_; - uint32_t byte_selector_; - -public: - CUTLASS_DEVICE - FragmentShuffler() { - int lane_id = cutlass::arch::LaneId(); - delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1); - delta_down_ = 2 - delta_up_; - odd_even_lane_id_ = static_cast(lane_id & 1); - byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread + - (1 - odd_even_lane_id_) * kSelectBytesEvenThread; - } - - CUTLASS_DEVICE - WarpFragment operator()(WarpFragment const &src) { - - WarpFragment result; - - MmaFragment const* mma_frag_src_ptr = reinterpret_cast(&src); - MmaFragment* mma_frag_dst_ptr = reinterpret_cast(&result); - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kNumMmaInstructions; n++) { - - uint32_t const* src_ptr = reinterpret_cast(&mma_frag_src_ptr[n]); - uint32_t* dst_ptr = reinterpret_cast(&mma_frag_dst_ptr[n]); - - // Shuffle data within the warp, pull from other threads within the warp - uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_); - uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_); - - // Reorder the data within the 32-bit word (4x8b) required for mma.sync - dst_ptr[0] = __byte_perm(tmp0, tmp1, byte_selector_); - } - - return result; - } - -}; - -//////////////////////////////////////////////////////////////////////////////// -// Data type conversion -//////////////////////////////////////////////////////////////////////////////// -template < - /// Destination type - typename ElementDst_, - /// Source type - typename ElementSrc_, - /// Number of elements - int N, - /// - typename Enable = void> -struct FragmentConverter { - - using ElementDst = ElementDst_; - using ElementSrc = ElementSrc_; - - // Operand fragment registers in destination and source types - using DestinationFragment = Array; - using SourceFragment = Array; - - FastNumericArrayConverter convert; - - CUTLASS_DEVICE - DestinationFragment operator()(SourceFragment const &src) const { - return convert(src); - } -}; -//////////////////////////////////////////////////////////////////////////////// - -// Partial specialization for when Destination type is the *same* as -// Source type -template < - /// Data type - typename Element, - /// Number of elements - int N, - /// - typename Enable> -struct FragmentConverter { - - using DestinationFragment = Array; - using SourceFragment = Array; - - CUTLASS_DEVICE - DestinationFragment operator()(SourceFragment const &src) const { - return src; - } -}; - -} // namespace detail - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> -class MmaMixedInputTensorOp { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Underlying arch::Mma instruction datatype for A operand - using ElementAMma = typename ArchMmaOperator::ElementA; - - /// Underlying arch::Mma instruction datatype for B operand - using ElementBMma = typename ArchMmaOperator::ElementB; - - /// Underlying arch::Mma instruction datatype for C operand - using MmaElementC = typename ArchMmaOperator::ElementC; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// - // static int const kLoadShapeK = InstructionShape::kK * - // (sizeof_bits::value / sizeof_bits::value); - -public: - - /// Iterates over the A operand in Shared Memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, LayoutA, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile in registers (loaded from Shared Memory) - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile in registers (for use in Mma instruction) - using TransformedFragmentA = - Array; - - /// Underlying arch::Mma instruction operand fragment for matrix A - using MmaOperandA = typename ArchMmaOperator::FragmentA; - - /// Iterates over the B operand in Shared Memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kB, ElementB, LayoutB, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for B tile in registers (loaded from Shared Memory) - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile in registers (for use in Mma instruction) - using TransformedFragmentB = - Array; - - /// Underlying arch::Mma instruction operand fragment for matrix B - using MmaOperandB = typename ArchMmaOperator::FragmentB; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Underlying arch::Mma instruction operand fragment for matrix C - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN - >; - - -public: - - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaMixedInputTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - - D = C; - - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } - } - - /// Transform the operand warp fragment register to the required data types and layout - /// for the `cultass::arch::Mma` - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - - // Shuffle data within warp to obtain the mma.sync operand layout - detail::FragmentShuffler shuffler_B; - FragmentB tmp_B; - tmp_B = shuffler_B(B); - - // Convert the B operand to the Mma Instruction operand type - detail::FragmentConverter convert_B; - dst_B = convert_B(tmp_B); - - FragmentA tmp_A; - - Array * - ptr_tmp_A = reinterpret_cast *>(&tmp_A); - Array * - ptr_dst_A = reinterpret_cast *>(&dst_A); - - // Shuffle data within warp to obtain the mma.sync operand layout - detail::FragmentShuffler shuffler_A; - - // Convert the A operand to the Mma Instruction operand type - detail::FragmentConverter convert_A; - - tmp_A = shuffler_A(A); - ptr_dst_A[0] = convert_A(ptr_tmp_A[0]); - - ptr_dst_A[1] = convert_A(ptr_tmp_A[1]); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h deleted file mode 100644 index af1031adb4a9e393135075a9a65553d8d7e17102..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h +++ /dev/null @@ -1,182 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/array_planar_complex.h" -#include "cutlass/gemm/warp/tile_iterator_planar_complex.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Underlying real-valued warp-level matrix multiply - typename Operator_, - /// Transformation applied to A operand (typically folded into math instruction) - ComplexTransform TransformA = ComplexTransform::kNone, - /// Transformation applied to B operand (typically folded into math instruction) - ComplexTransform TransformB = ComplexTransform::kNone -> -class MmaPlanarComplex { -public: - - /// Underlying real-valued warp-level matrix multiply - using Operator = Operator_; - - /// Shape of warp-level matrix multipy - using Shape = typename Operator::Shape; - - /// Transformation applied to A operand (typically folded into math instruction) - static ComplexTransform const kTransformA = TransformA; - - /// Transformation applied to B operand (typically folded into math instruction) - static ComplexTransform const kTransformB = TransformB; - - /// Fragment of elements - using FragmentA = ArrayPlanarComplex; - - /// Iterator into planar complex - using IteratorA = TileIteratorPlanarComplex; - - /// Layout in memory of the A operand - using LayoutA = typename Operator::LayoutA; - - using FragmentB = ArrayPlanarComplex; - - /// Iterator into planar complex - using IteratorB = TileIteratorPlanarComplex; - - /// Layout in memory of the B operand - using LayoutB = typename Operator::LayoutB; - - /// Tile iterator for accumulator - using IteratorC = TileIteratorPlanarComplex; - - /// Accumulator fragment - using FragmentC = ArrayPlanarComplex; - - /// Layout of accumulator fragment in memory - using LayoutC = typename Operator::LayoutC; - -private: - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - Operator::Shape::kM / Operator::Policy::Operator::Shape::kM, - Operator::Shape::kN / Operator::Policy::Operator::Shape::kN - >; - -public: - /// Ctor - CUTLASS_DEVICE - MmaPlanarComplex() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A_in, - FragmentB const &B_in, - FragmentC const &C) const { - - D.real = C.real; - D.imag = C.imag; - - // - // Transform fragments based on conjugate operations. - // - - negate neg_A; - - FragmentA frag_A; - frag_A.real = A_in.real; - - if (kTransformA == ComplexTransform::kConjugate) { - frag_A.imag = neg_A(frag_A.imag); - } - else { - frag_A.imag = frag_A.imag; - } - - FragmentB frag_B; - frag_B.real = B_in.real; - - if (kTransformB == ComplexTransform::kConjugate) { - negate neg; - frag_B.imag = neg(frag_B.imag); - } - else { - frag_B.imag = frag_B.imag; - } - - // - // Accumulated real-valued matrix multiplies - // - - Operator real_mma; - - // D.i += A.i * B.r - real_mma(D.imag, frag_A.imag, frag_B.real, D.imag); - - // D.r += A.r * B.r - real_mma(D.real, frag_A.real, frag_B.real, D.real); - - // D.i += A.r * B.i - real_mma(D.imag, frag_A.real, frag_B.imag, D.imag); - - // D.r += -A.i * B.i - frag_A.imag = neg_A(frag_A.imag); - real_mma(D.real, frag_A.imag, frag_B.imag, D.real); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt.h deleted file mode 100644 index c4152da36fe767dcbad2faca27ca22e282b6b0c5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt.h +++ /dev/null @@ -1,263 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/thread/mma.h" - -#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -#include "cutlass/gemm/warp/mma_simt_policy.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK = 1, - /// Complex transformation on operand A - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transformation on operand B - ComplexTransform TransformB = ComplexTransform::kNone, - /// Used for partial specialization - typename Enable = bool -> -class MmaSimt { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassSimt; - - /// Hard-coded for now - using ArchTag = arch::Sm50; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - /// Layout of threads - using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value, - layout::ColumnMajor, - typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value, - layout::RowMajor, - LayoutA>::type - >::type; - - using ThreadLayoutB = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutB >::value, - layout::ColumnMajor, - typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutB >::value, - layout::RowMajor, - LayoutB>::type - >::type; - - static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || - platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && - platform::is_same< ElementA, int8_t >::value && - platform::is_same< ElementB, int8_t >::value; - - using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; - - /// Thread-level matrix multiply accumulate operator - using ThreadMma = thread::Mma< - GemmShape< - Shape::kM / Policy::WarpShape::kRow, - Shape::kN / Policy::WarpShape::kColumn, - Policy::LaneMmaShape::kK>, - ElementA, - ThreadLayoutA, - ElementB, - ThreadLayoutB, - ElementC, - LayoutC, - arch::OpMultiplyAdd, - dp4a_type - >; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Shape of the underlying instruction - using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaSimtTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - Policy, - PartitionsK, - Shape::kK - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = FragmentA; - - /// Iterates over the B operand in memory - using IteratorB = MmaSimtTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - Policy, - PartitionsK, - Shape::kK - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentB = FragmentB; - - /// Iterates over the C operand in memory - using IteratorC = MmaSimtTileIterator< - MatrixShape, - Operand::kC, - ElementC, - LayoutC, - Policy - >; - - /// Storage for C tile - using FragmentC = typename ThreadMma::FragmentC; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaSimt() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &d, - FragmentA a, - FragmentB b, - FragmentC const &c, int group_idx = 0) const { - - ThreadMma mma; - - if (kTransformA == ComplexTransform::kConjugate) { - conjugate conj_a; - a = conj_a(a); - } - - if (kTransformB == ComplexTransform::kConjugate) { - conjugate conj_b; - b = conj_b(b); - } - - mma(d, a, b, c); - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - dst_A = A; - dst_B = B; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h deleted file mode 100644 index 9bca2348e89a3877ab517a833ba2084cc2f5abb5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h +++ /dev/null @@ -1,69 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT - instructions -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Describes the arrangement and configuration of per-lane operations in warp-level matrix multiply -template < - typename WarpShape_, ///< shape of the warp in lanes (concept: MatrixShape) - typename LaneLayout_, ///< layout function of lanes - typename LaneMmaShape_ ///< size of each lane's thread-level matrix product (concept: GemmShape) -> -struct MmaSimtPolicy { - using WarpShape = WarpShape_; - using LaneLayout = LaneLayout_; - using LaneMmaShape = LaneMmaShape_; - using MmaShape = LaneMmaShape; - - /// Returns a layout functor mapping lane position in the warp to thread ID - CUTLASS_HOST_DEVICE - static LaneLayout get_lane_layout() { - return LaneLayout::packed({WarpShape::kRow, WarpShape::kColumn}); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h deleted file mode 100644 index c522eafa5ef5aa6fff18a196e27d777e05dd753e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h +++ /dev/null @@ -1,1890 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT - instructions -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" - -#include "cutlass/layout/matrix.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma_simt_policy.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions -/// -/// concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK = 1, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize = 1 -> -class MmaSimtTileIterator; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for A operands of column-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize -> -class MmaSimtTileIterator { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::ColumnMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert(!(Shape::kRow % Policy::WarpShape::kRow), - "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - - /// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow / Policy::WarpShape::kRow, - Shape::kColumn - >; - - static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow / Policy::LaneMmaShape::kM, - ThreadShape::kColumn - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Internal reference - cutlass::TensorRef, layout::ColumnMajor> ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - int lane_id - ) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(Policy::LaneMmaShape::kM, 0); - - ref.add_coord_offset(lane_offset); - - ref_.reset( - reinterpret_cast *>(ref.data()), - ref.stride(0) / Policy::LaneMmaShape::kM); - } - - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - ref_.add_coord_offset({ - coord.row() * Shape::kRow / Policy::LaneMmaShape::kM, - coord.column() * Shape::kColumn}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - ref_.add_coord_offset({0, Shape::kColumn}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({0, -Shape::kColumn}); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - Array *dst_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kColumn; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - - // This logic has been replaced with calls to inline PTX to guarantee vectorization. - #if 0 - dst_ptr[m + k * Iterations::kRow] = - *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM); - #endif - - auto ptr = ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM; - arch::shared_load(dst_ptr[m + k * Iterations::kRow], ptr); - } - } - } - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - Array const *src_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kN; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kM; ++m) { - *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) = - src_ptr[m + k * Iterations::kM]; - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for A operands of row-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - used in sliced-K - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize -> -class MmaSimtTileIterator { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::RowMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert(!(Shape::kRow % Policy::WarpShape::kRow), - "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - - /// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow / Policy::WarpShape::kRow, - Shape::kColumn - >; - - static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads (scalar loads) - using Iterations = MatrixShape< - ThreadShape::kRow / Policy::LaneMmaShape::kM, - ThreadShape::kColumn - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Internal reference - cutlass::TensorRef ref_; - - /// Extent of tensor - MatrixCoord extent_; - - /// Origin - MatrixCoord origin_; - - /// Used to conditionally enable extents checking - bool divisible_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() : divisible_(true) { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - int lane_id - ) : extent_(Shape::kRow, Shape::kColumn), divisible_ (true) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(Policy::LaneMmaShape::kM, 0); - - origin_ = lane_offset; - - ref.add_coord_offset(lane_offset); - - ref_.reset(ref.data(), ref.stride(0)); - - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - TensorCoord extent, - int lane_id - ) : extent_(extent), divisible_ (false) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(Policy::LaneMmaShape::kM, 0); - - origin_ = lane_offset; - - ref.add_coord_offset(lane_offset); - - ref_.reset(ref.data(), ref.stride(0)); - - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - TensorCoord coord_offset( - coord.row() * Shape::kRow, - coord.column() * Shape::kColumn); - - origin_ += coord_offset; - - ref_.add_coord_offset(coord_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - ref_.add_coord_offset({0, Shape::kColumn}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({0, -Shape::kColumn}); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (scalar loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kColumn; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { - - MatrixCoord offset(m * Policy::WarpShape::kRow * Policy::LaneMmaShape::kM + i, k); - - MatrixCoord access_coord = origin_ + offset; - - int frag_idx = m * Policy::LaneMmaShape::kM + i + k * Iterations::kRow; - - if (divisible_ || - (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { - - frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); - } - else { - frag[frag_idx] = Element(); - } - } - } - } - } - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kColumn; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { - - *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM * Policy::LaneMmaShape::kM + i, k) + pointer_offset) = - frag[m * Policy::LaneMmaShape::kM + i + k * Iterations::kM]; - } - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for B operands of row-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize -> -class MmaSimtTileIterator { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kB; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::RowMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), - "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); - static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); - - /// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow, - Shape::kColumn / Policy::WarpShape::kColumn - >; - - static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow, - ThreadShape::kColumn / Policy::LaneMmaShape::kN - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -protected: - - /// Internal reference - cutlass::TensorRef, layout::RowMajor> ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - int lane_id - ) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(0, Policy::LaneMmaShape::kN); - - ref.add_coord_offset(lane_offset); - - ref_.reset( - reinterpret_cast *>(ref.data()), - ref.stride(0) / Policy::LaneMmaShape::kN); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - ref_.add_coord_offset({ - coord.row() * Shape::kRow, - coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - ref_.add_coord_offset({Shape::kRow, 0}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({-Shape::kRow, 0}); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - Array *dst_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kRow; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - - #if 0 - dst_ptr[n + k * Iterations::kColumn] = - *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN); - #endif - - void const *ptr = ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN; - arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - Array const *src_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kM; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kN; ++n) { - *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = - src_ptr[n + k * Iterations::kN]; - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, Index pointer_offset) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for B operands of column-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK, - /// Group Size along kPartition - used in sliced-K - int PartitionGroupSize -> -class MmaSimtTileIterator { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kB; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::ColumnMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), - "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); - static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); - - /// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow, - Shape::kColumn / Policy::WarpShape::kColumn - >; - - static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow, - ThreadShape::kColumn / Policy::LaneMmaShape::kN - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Internal reference - cutlass::TensorRef ref_; - - /// Extent of tensor - MatrixCoord extent_; - - /// Origin - MatrixCoord origin_; - - /// Used to conditionally enable extents checking - bool divisible_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator(): divisible_(true) { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - int lane_id - ): extent_(Shape::kRow, Shape::kColumn), divisible_(true) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(0, Policy::LaneMmaShape::kN); - - origin_ = lane_offset; - - ref.add_coord_offset(lane_offset); - - ref_.reset(ref.data(), ref.stride(0)); - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - TensorCoord extent, - int lane_id - ): extent_(extent), divisible_(false) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(0, Policy::LaneMmaShape::kN); - - origin_ = lane_offset; - - ref.add_coord_offset(lane_offset); - - ref_.reset(ref.data(), ref.stride(0)); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - TensorCoord coord_offset( - coord.row() * Shape::kRow, - coord.column() * Shape::kColumn); - - origin_ += coord_offset; - - ref_.add_coord_offset(coord_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - ref_.add_coord_offset({Shape::kRow, 0}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({-Shape::kRow, 0}); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. (scalar loads) - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kRow; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Policy::LaneMmaShape::kN; ++i) { - - MatrixCoord offset(k, n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + i); - - MatrixCoord access_coord = origin_ + offset; - - int frag_idx = n * Policy::LaneMmaShape::kN + i + k * Iterations::kColumn; - - if (divisible_ || - (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { - - frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); - } - else { - frag[frag_idx] = Element(); - } - } - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - Array const *src_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kM; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kN; ++n) { - *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = - src_ptr[n + k * Iterations::kN]; - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, Index pointer_offset) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for C operands of column-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_ -> -class MmaSimtTileIterator { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of accumulators in memory - using Layout = layout::ColumnMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert( - (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)), - "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); - - /// Thraed-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow / Policy::WarpShape::kRow, - Shape::kColumn / Policy::WarpShape::kColumn - >; - - static_assert( - (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)), - "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow / Policy::LaneMmaShape::kM, - ThreadShape::kColumn / Policy::LaneMmaShape::kN - >; - - using Delta = MatrixShape< - Policy::WarpShape::kRow * Policy::LaneMmaShape::kM, - Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - ref_.add_coord_offset({ - coord.row() * Shape::kRow, - coord.column() * Shape::kColumn}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - ref_.add_coord_offset({Shape::kRow, 0}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({-Shape::kRow, 0}); - - return *this; - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_HOST_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to be loaded from memory - Index pointer_offset) const { ///< linear offset (in units of Element) when loading - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Iterations::kN; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { - - Array const *src_ptr = - reinterpret_cast const *>( - ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kN + n})); - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Iterations::kM; ++mma_m) { - - Array *dst_ptr = - reinterpret_cast *>(&frag) + - mma_m + Iterations::kM * (n + mma_n * Policy::LaneMmaShape::kN); - - *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kM]; - } - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { - - Array *dst_ptr= - reinterpret_cast *>( - ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n})); - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { - - Array const *src_ptr = - reinterpret_cast const *>(&frag) + - mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN); - - dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr; - } - } - } - } - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for C operands of row-major layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_ -> -class MmaSimtTileIterator { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of accumulators in memory - using Layout = layout::RowMajor; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - // - // Derived quantities - // - - static_assert( - (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)), - "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); - - /// Thraed-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow / Policy::WarpShape::kRow, - Shape::kColumn / Policy::WarpShape::kColumn - >; - - static_assert( - (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)), - "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow / Policy::LaneMmaShape::kM, - ThreadShape::kColumn / Policy::LaneMmaShape::kN - >; - - using Delta = MatrixShape< - Policy::WarpShape::kRow * Policy::LaneMmaShape::kM, - Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - ref_.add_coord_offset({ - coord.row() * Shape::kRow, - coord.column() * Shape::kColumn}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - ref_.add_coord_offset({Shape::kRow, 0}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({-Shape::kRow, 0}); - - return *this; - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_HOST_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to be loaded from memory - Index pointer_offset) const { ///< linear offset (in units of Element) when loading - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { - - Array const *src_ptr = - reinterpret_cast const *>( - ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0})); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { - - Array *dst_ptr = - reinterpret_cast *>(&frag) + - mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM); - - *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn]; - } - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { - - Array *dst_ptr = - reinterpret_cast *>( - ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0})); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { - - Array const *src_ptr = - reinterpret_cast const *>(&frag) + - mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM); - - dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr; - } - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for A operands of column-major-K interleaved layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK, - /// Number of KGroups per kPartition - int PartitionGroupSize -> -class MmaSimtTileIterator, Policy_, PartitionsK, PartitionGroupSize> { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::ColumnMajorInterleaved<4> ; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Iterleave factor - static const int kInterleave = 4; - - /// Number of partitions along K dimension - static const int kPartitionsK = PartitionsK; - - /// Number of KGroups per kPartition - static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn; - - // - // Derived quantities - // - - static_assert(!(Shape::kRow % Policy::WarpShape::kRow), - "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); - static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); - - /// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow / Policy::WarpShape::kRow, - Shape::kColumn - >; - - static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow / Policy::LaneMmaShape::kM, - ThreadShape::kColumn / Policy::LaneMmaShape::kK - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Internal reference - cutlass::TensorRef, layout::ColumnMajorInterleaved<4>> ref_; - - /// group index within tile - int k_group_idx_; - -public: - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - int lane_id - ) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(Policy::LaneMmaShape::kM, 0); - - ref.add_coord_offset(lane_offset); - - k_group_idx_ = 0; - ref_.reset(reinterpret_cast *>(ref.data()), ref.stride(0)/Policy::LaneMmaShape::kMK); - } - - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - ref_.add_coord_offset({ - coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK, - coord.column() * Shape::kColumn}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - add_tile_offset({0, 1}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == kGroupPerTile) { - k_group_idx_ = 0; - add_tile_offset({0, kGroupPerTile * (kPartitionsK-1)}); - } - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({0, -Shape::kColumn}); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - Array *dst_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kColumn; ++k) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - - dst_ptr[m + k * Iterations::kRow] = - *((ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave, - k*Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM)); - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - Array const *src_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kN; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kM; ++m) { - *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) = - src_ptr[m + k * Iterations::kM]; - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization for B operands of row-major k-interleaved layouts -/// -/// Concept: MutableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Shape of the warp in units of thread (concept: MmaSimtPolicy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK, - /// Number of KGroups per kPartition - int PartitionGroupSize -> -class MmaSimtTileIterator, Policy_, PartitionsK, PartitionGroupSize> { -public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kB; - - /// Element type - using Element = Element_; - - /// Layout of policy - using Layout = layout::RowMajorInterleaved<4>; - - /// Decomposition of elements among threads - using Policy = Policy_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Interleave factor - static const int kInterleave = 4; - - /// Number of partitions along K dimension - static const int kPartitionsK = PartitionsK; - - /// Number of KGroups per kPartition - static const int kGroupPerTile = PartitionGroupSize / Shape::kRow; - - // - // Derived quantities - // - - static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), - "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); - - static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); - static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); - static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); - static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); - - /// Thread-level shape of a fragment - using ThreadShape = MatrixShape< - Shape::kRow, - Shape::kColumn / Policy::WarpShape::kColumn - >; - - static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK), - "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); - - /// Number of individual loads - using Iterations = MatrixShape< - ThreadShape::kRow / Policy::LaneMmaShape::kK, - ThreadShape::kColumn / Policy::LaneMmaShape::kN - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - - -private: - - /// Internal reference - cutlass::TensorRef, layout::RowMajorInterleaved<4>> ref_; - - /// group index within tile - int k_group_idx_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaSimtTileIterator( - TensorRef ref, - int lane_id - ) { - - // compute offset based on thread ID and lane layout - typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); - - MatrixCoord lane_offset = lane_layout.inverse(lane_id) * - MatrixCoord(0, Policy::LaneMmaShape::kN); - - ref.add_coord_offset(lane_offset); - - k_group_idx_ = 0; - - ref_.reset( - reinterpret_cast *>(ref.data()), - ref.stride(0) / Policy::LaneMmaShape::kKN); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - - ref_.add_coord_offset({ - coord.row() * Shape::kRow, - coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator++() { - - add_tile_offset({1, 0}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == kGroupPerTile) { - k_group_idx_ = 0; - add_tile_offset({kGroupPerTile * (kPartitionsK-1), 0}); - } - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaSimtTileIterator & operator--() { - - ref_.add_coord_offset({-Shape::kRow, 0}); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - Array *dst_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kRow; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - dst_ptr[n + k * Iterations::kColumn] = - *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK, - n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN); - } - } - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - Array const *src_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kM; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kN; ++n) { - *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = - src_ptr[n + k * Iterations::kN]; - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, Index pointer_offset) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h deleted file mode 100644 index 902a3d10674c99428ed36404dbdbc27555fc46a7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h +++ /dev/null @@ -1,382 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate - operations targeting sparse Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> -class SparseMmaTensorOp { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Equivalent base dense mma - using Base = MmaTensorOp; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Base::ArchMmaOperator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename Base::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = typename Base::OperatorClass; - - /// Shape of underlying instruction - using InstructionShape = typename Base::InstructionShape; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Base::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Base::kTransformB; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// Sparsity in Operand A - static int const kSparse = Policy::Operator::kSparse; - - /// Meta data size in bits - static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; - - /// Max ID2 - static int const kMaxID2 = Policy::Operator::kMaxID2; - - static int const kVerticalVisit = false; - /// Data type of meta E that is moved at the same time - using ElementE = - typename cutlass::platform::conditional::type; - - /// Number of ElementA that is associated with one ElementE - static int const kElementsPerElementE = - 128 / cutlass::sizeof_bits::value; - - /// Meta data is essentially interleaved but mapped to ColumnMajor internally - static int const kInterleaved = 2; - - /// Layout of meta E - using LayoutE = cutlass::layout::ColumnMajor; - - public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = - Array; - - /// Iterates over the B operand in memory - using IteratorB = typename Base::IteratorB; - - /// Storage for B tile - using FragmentB = typename Base::FragmentB; - - /// Storage for transformed B tile - using TransformedFragmentB = typename Base::TransformedFragmentB; - - /// Iterates over the C operand in memory - using IteratorC = typename Base::IteratorC; - - /// Storage for C tile - using FragmentC = typename Base::FragmentC; - - /// Iterates over the E operand in memory - using IteratorE = SparseMmaTensorOpMetaTileIterator< - MatrixShape, - ElementE, LayoutE, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for E tile - using FragmentE = typename IteratorE::Fragment; - - /// Number of mma operations performed - using MmaIterations = typename Base::MmaIterations; - -public: - - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - SparseMmaTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C, - FragmentE const &E - ) const { - - using MmaOperandA = typename Policy::Operator::FragmentA; - using MmaOperandB = typename Policy::Operator::FragmentB; - using MmaOperandC = typename Policy::Operator::FragmentC; - using MmaOperandE = typename Policy::Operator::FragmentE; - - D = C; - - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - MmaOperandE const *ptr_E = reinterpret_cast(&E); - - if (kVerticalVisit) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - int id2 = m_serpentine % kMaxID2; - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n], - ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_E[(m_serpentine / kMaxID2)], - id2); - } else { - mma( - ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n], - ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_E[(m_serpentine / kMaxID2)], - id2); - } - } - } - } else { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int id2 = m % kMaxID2; - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_E[(m / kMaxID2)], - id2); - } else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_E[(m / kMaxID2)], - id2); - } - } - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - - // - // Define conversions from source type to instruction type - // - FloatRoundStyle const kRoundA = - PreferredRoundingMode::kRound; - FloatRoundStyle const kRoundB = - PreferredRoundingMode::kRound; - - if (kVerticalVisit) { - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_B = - reinterpret_cast const *>(&B); - Array * - ptr_dst_B = reinterpret_cast *>(&dst_B); - - dst_A = convert_A(A); - - ptr_dst_B[0] = convert_B(ptr_B[0]); - ptr_dst_B[1] = convert_B(ptr_B[1]); - } else { - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_A = - reinterpret_cast const *>(&A); - Array * - ptr_dst_A = reinterpret_cast *>(&dst_A); - - dst_B = convert_B(B); - - ptr_dst_A[0] = convert_A(ptr_A[0]); - ptr_dst_A[1] = convert_A(ptr_A[1]); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h deleted file mode 100644 index 190e92fc5a036e2ce038983130e07c27e25deced..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h +++ /dev/null @@ -1,417 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - return converter(source); - } -}; - -template -struct ConvertAndPack { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - return source; - } -}; - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - Array tmp; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); - tmp[i] = source[idx]; - } - - return converter(tmp); - } -}; - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - Array tmp; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); - tmp[i] = source[idx]; - } - - return converter(tmp); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting Tensor Cores. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> -class MmaTensorOp { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ < 800) || (__CUDA_ARCH__ == 890)) - static int const kVerticalVisit = true; - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1200) - static int const kVerticalVisit = true; - #else - static int const kVerticalVisit = false; - #endif - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, LayoutA, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = - Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kB, ElementB, LayoutB, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = - Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN - >; - -public: - - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - D = C; - - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - - - if (kVerticalVisit) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } else { - mma( - ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } - } - } else { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } - } - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - - // - // Define conversions from source type to instruction type - // - FloatRoundStyle const kRoundA = - PreferredRoundingMode::kRound; - FloatRoundStyle const kRoundB = - PreferredRoundingMode::kRound; - if (kVerticalVisit) { - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_B = - reinterpret_cast const *>(&B); - Array * - ptr_dst_B = reinterpret_cast *>(&dst_B); - - dst_A = convert_A(A); - - ptr_dst_B[0] = convert_B(ptr_B[0]); - ptr_dst_B[1] = convert_B(ptr_B[1]); - } else { - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_A = - reinterpret_cast const *>(&A); - Array * - ptr_dst_A = reinterpret_cast *>(&dst_A); - - dst_B = convert_B(B); - - ptr_dst_A[0] = convert_A(ptr_A[0]); - ptr_dst_A[1] = convert_A(ptr_A[1]); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h deleted file mode 100644 index 570298bccdae2e014a32b8ad31b32d84bd4332bd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h +++ /dev/null @@ -1,471 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -enum class TensorFloat32Op { - k3xTF32, - k4xTF32 -}; - -template < - /// Floating-point rounding style - FloatRoundStyle RoundBigA_, - /// Floating-point rounding style - FloatRoundStyle RoundSmallA_, - /// Floating-point rounding style - FloatRoundStyle RoundBigB_ = RoundBigA_, - /// Floating-point rounding style - FloatRoundStyle RoundSmallB_ = RoundSmallA_, - /// Precision for TensorFloat32Op - // (k3xTF32: BigxBig, BigxSmall, SmallxBig) - // (k4xTF32: BigxBig, BigxSmall, SmallxBig, SmallxSmall) - TensorFloat32Op Precision_ = TensorFloat32Op::k3xTF32 - > -struct FastF32 { - - static FloatRoundStyle const kRoundBigA = RoundBigA_; - static FloatRoundStyle const kRoundSmallA = RoundSmallA_; - static FloatRoundStyle const kRoundBigB = RoundBigB_; - static FloatRoundStyle const kRoundSmallB = RoundSmallB_; - static TensorFloat32Op const kPrecision = Precision_; -}; - - -namespace detail { - - template< - int N, - FloatRoundStyle RoundBig = FloatRoundStyle::round_toward_zero, - FloatRoundStyle RoundSmall = FloatRoundStyle::round_half_ulp_truncate - > - struct ConvertAndPackAccurateF32 { - - /// Rounding styles for big and small part - static FloatRoundStyle const kRoundBig = RoundBig; - static FloatRoundStyle const kRoundSmall = RoundSmall; - - /// Converter type - using Converter = NumericConverterFastF32; - - /// Source fragement - using SourceFragment = Array; - - /// Destination fragment - using DestinationFragment = Array; - - /// Converter Fragment holding two tfloat32_t elements for every float - using ConverterFragment = Array; - - /// Index in fargments for the big and small part - static int const kBigIndex = 0; - static int const kSmallIndex = 1; - - CUTLASS_HOST_DEVICE - void operator()(SourceFragment const &source, - DestinationFragment &dst_big, - DestinationFragment &dst_small) { - - Converter convert_; - ConverterFragment result_; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - // convert source to result fragment - result_ = convert_(source[i]); - - // store converted result fragments to destination fragment - dst_big[i] = result_[kBigIndex]; - dst_small[i] = result_[kSmallIndex]; - } - } - }; -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> -class MmaTensorOpFastF32; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for float*float+float => float using TF32 TensorOps -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor, - /// Used for partial specialization - typename Enable -> -class MmaTensorOpFastF32< - Shape_, - float, LayoutA_, - float, LayoutB_, - float, LayoutC_, - Policy_, PartitionsK_, - AccumulatorsInRowMajor, Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = float; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = float; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = float; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = arch::OpMultiplyAddFastF32; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// Tune F32 to TF32 big small conversion for float operation - /// Different combination of big small conversin can cause different tradeoff - /// between speed and accuracy. Generally, use round_half_ulp_truncate can - /// improve the performance but hur the accuracy. - using MmaFastF32 = FastF32 < - FloatRoundStyle::round_toward_zero, // kRoundBigA - FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallA - FloatRoundStyle::round_toward_zero, // kRoundBigB - FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallB - TensorFloat32Op::k3xTF32 // Number of TF32 operations - >; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = - Array; - - /// Fragment bisecting big and small sections - using AccessTypeFragmentA = - Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = - Array; - - /// Fragment bisecting big and small sections - using AccessTypeFragmentB = - Array; - - /// Index in fargments for the big and small part - static int const kBigIndex = 0; - static int const kSmallIndex = 1; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN - >; - -public: - - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpFastF32() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - - AccessTypeFragmentA const *ptr_A = reinterpret_cast(&A); - AccessTypeFragmentB const *ptr_B = reinterpret_cast(&B); - - // - // Accumulate in place - // - D = C; - - mma_operator(D, ptr_A[kSmallIndex], ptr_B[kBigIndex], D); - - mma_operator(D, ptr_A[kBigIndex], ptr_B[kSmallIndex], D); - - mma_operator(D, ptr_A[kBigIndex], ptr_B[kBigIndex], D); - - if (MmaFastF32::kPrecision == TensorFloat32Op::k4xTF32) - mma_operator(D, ptr_A[kSmallIndex], ptr_B[kSmallIndex], D); - } - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void mma_operator( - FragmentC &D, - AccessTypeFragmentA const &A, - AccessTypeFragmentB const &B, - FragmentC const &C - ) const { - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // This allows to reuse of Rb when at serpentine turns - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } else { - mma( - ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } // end n loop - } // end m loop - #else - assert(0); - #endif - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - - // - // Define conversions from source type to instruction type - // - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - detail::ConvertAndPackAccurateF32< - FragmentA::kElements / 2, - MmaFastF32::kRoundBigA, - MmaFastF32::kRoundSmallA> convert_A; - - detail::ConvertAndPackAccurateF32< - FragmentB::kElements, - MmaFastF32::kRoundBigB, - MmaFastF32::kRoundSmallB> convert_B; - - Array *ptr_dst_B = - reinterpret_cast *>(&dst_B); - - convert_B(B, ptr_dst_B[0], ptr_dst_B[1]); - - Array *ptr_dst_A = - reinterpret_cast *>(&dst_A); - - Array const *ptr_A = - reinterpret_cast const *>(&A); - - convert_A(ptr_A[0], ptr_dst_A[0], ptr_dst_A[2]); - - convert_A(ptr_A[1], ptr_dst_A[1], ptr_dst_A[3]); - #else - assert(0); - #endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h deleted file mode 100644 index c70bc581dd5a77d9d17c533717d8a7b3693b55ad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ /dev/null @@ -1,559 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief This defines a "fragment" iterator for visiting the fragments of a warp tile - that participate in one warp-level mma operation. - - Typically, this is used to access the accumulator tile/fragment of a warp-level mma operation. - The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into - next warp-level mma operation. - - This iterator is necessary to accomplish warp-level mma fusion where the accumulator tile is - reused as multiplicand tile for the next mma. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_conversion.h" - -namespace cutlass { -namespace gemm { -namespace warp { - - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Size of the accumulation tile shape (concept: MatrixShape) - typename AccumulatorShape_, - /// KBlocks columns to compute residual - int KBlocksColumn_, - /// Accumulator Element type - typename ElementAccumulator_, - /// Element type - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Output operation on the fragment - typename OutputOp_> -class MmaTensorOpFragmentIterator; - - -// Partial specialization for col-major accumulator tile - -template < - /// Shape of warp tile to load (concept: MatrixShape) - typename Shape_, - /// Shape of the warp accumulation tile (concept: MatrixShape) - typename AccumulatorShape_, - /// KBlocks columns to compute residual - int KBlocksColumn_, - /// Accumulator Element type - typename ElementAccumulator_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Output operation on fragment - typename OutputOp_> -class MmaTensorOpFragmentIterator { - public: - - /// Shape of warp tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Shape of the warp accumulation tile (concept: MatrixShape) - using AccumulatorShape = AccumulatorShape_; - - /// KBlocks columns to compute residual - static int const kKBlockColumn = KBlocksColumn_; - - /// Accumulator Element type - using ElementAccumulator = ElementAccumulator_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Output operation on fragment - using OutputOp = OutputOp_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - static_assert( - AccumulatorShape::kRow == Shape::kRow, - "Rows of Warp Accumulator must be the same as rows of warp"); - static_assert( - !(AccumulatorShape::kColumn % Shape::kColumn), - "Shape of Warp Accumulator must be divisible by warp shape."); - static_assert( - !(kKBlockColumn % Shape::kColumn), - "KBlock size must be divisible by warp shape."); - - /// Number of times this iterator can be incremented - static int const kIterations = AccumulatorShape::kCount / Shape::kCount; - }; - -private: - - static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; - - /// Number of mma operations performed by a warp - using MmaIterations = MatrixShape; - /// Number of mma operations performed by the entire accumulator - using AccumulatorIterations = MatrixShape; - - /// Number of K iterations - static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; - static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; - static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn - * (AccumulatorShape::kRow / Shape::kRow); - static int const kResidualIndex = kResidualColumn / Shape::kColumn - * (AccumulatorShape::kRow / Shape::kRow); - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - /// This is the fragment size produced by one access of the iterator. - using Fragment = Array; - - /// Accumulator Fragment object - using AccumulatorFragment = Array; - - /// Scale Bias Element Type - using ElementScaleBias = typename OutputOp::ElementCompute; - - /// Scale Bias Fragment object - using ScaleBiasFragment = Array; - - -private: - - /// Internal access type - using AccessType = Array; - using FragmentAccessType = Array; - - using ScaleBiasAccessType = Array; - -private: - // - // Data members - // - - /// Accumulator tile - AccessType const *accumulators_; - - /// Internal index - int index_; - - /// Used to access residual tile first - bool is_residual_tile_; - -public: - /// Constructs an iterator - CUTLASS_HOST_DEVICE - MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) - : accumulators_(reinterpret_cast(&accum)), - index_(0), is_residual_tile_(true) {} - - /// Add offset - CUTLASS_HOST_DEVICE - void add_offset(int index_offset) { - index_ += index_offset; - if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { - index_ = index_ - kKBlockColumnIterations + kResidualIndex; - is_residual_tile_ = false; - } - } - - /// Increments - CUTLASS_HOST_DEVICE - MmaTensorOpFragmentIterator &operator++() { - add_offset(1); - return *this; - } - - /// Decrements - CUTLASS_HOST_DEVICE - MmaTensorOpFragmentIterator &operator--() { - add_offset(-1); - return *this; - } - - /// Loads a fragment from the referenced part of the accumulator tile - CUTLASS_HOST_DEVICE - void load(Fragment &frag, OutputOp output_op) const { - - if (output_op.is_source_needed()) //beta must be zero - assert(0); - - FragmentAccessType *frag_ptr = reinterpret_cast(&frag); - - int index = index_ * MmaIterations::kCount; - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; n++) { - for (int m = 0; m < MmaIterations::kRow; m++) { - int accumulator_access_offset = - n * AccumulatorIterations::kRow + m + index; - - frag_ptr[m * MmaIterations::kColumn + n].clear(); - if(!(is_residual_tile_ && index_ >= kResidualIndex)) - frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset]); - } - } - } - - /// Loads a fragment from the referenced part of the accumulator tile - /// Then apply per-channel scale and bias - CUTLASS_HOST_DEVICE - void load(Fragment &frag, ScaleBiasFragment &scale, - ScaleBiasFragment &bias, OutputOp output_op) const { - - if (output_op.is_source_needed()) //beta must be zero - assert(0); - - FragmentAccessType *frag_ptr = reinterpret_cast(&frag); - ScaleBiasAccessType * scale_ptr = reinterpret_cast(&scale); - ScaleBiasAccessType * bias_ptr = reinterpret_cast(&bias); - - int index = index_ * MmaIterations::kCount; - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; n++) { - for (int m = 0; m < MmaIterations::kRow; m++) { - int accumulator_access_offset = - n * AccumulatorIterations::kRow + m + index; - - frag_ptr[m * MmaIterations::kColumn + n].clear(); - if(!(is_residual_tile_ && index_ >= kResidualIndex)) - frag_ptr[m * MmaIterations::kColumn + n] = - output_op(accumulators_[accumulator_access_offset], - scale_ptr[n] /*scale*/, bias_ptr[n] /*bias*/); - } - } - } - - - -}; - -// Partial specialization for row-major accumulator tile - -template < - /// Shape of warp tile to load (concept: MatrixShape) - typename Shape_, - /// Shape of the warp accumulation tile (concept: MatrixShape) - typename AccumulatorShape_, - /// KBlocks columns to compute residual - int KBlocksColumn_, - /// Accumulator Element type - typename ElementAccumulator_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Output operation on fragment - typename OutputOp_> -class MmaTensorOpFragmentIterator { - public: - - /// Shape of warp tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Shape of the warp accumulation tile (concept: MatrixShape) - using AccumulatorShape = AccumulatorShape_; - - /// KBlocks columns to compute residual - static int const kKBlockColumn = KBlocksColumn_; - - /// Accumulator Element type - using ElementAccumulator = ElementAccumulator_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Output operation on fragment - using OutputOp = OutputOp_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - static_assert( - AccumulatorShape::kRow == Shape::kRow, - "Rows of Warp Accumulator must be the same as rows of warp"); - static_assert( - !(AccumulatorShape::kColumn % Shape::kColumn), - "Shape of Warp Accumulator must be divisible by warp shape."); - static_assert( - !(kKBlockColumn % Shape::kColumn), - "KBlock size must be divisible by warp shape."); - - /// Number of times this iterator can be incremented - static int const kIterations = AccumulatorShape::kCount / Shape::kCount; - }; - -private: - - static int const kRowsPerIteration = 8; - static int const kColumnsPerIteration = 16; - static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kN / kThreads; - static int const kElementsPerAccess = kRowsPerIteration * kColumnsPerIteration / kThreads; - static int const kIterationsPerAccess = kElementsPerAccess / kElementsPerIteration; - - // Number of iterations per actual instruction - static int const kIterationsPerInstruction = InstructionShape::kM / kRowsPerIteration; - - static int const kAccessStride = kIterationsPerInstruction; - - /// Number of mma operations performed by a warp - using MmaIterations = MatrixShape; - /// Number of mma operations performed by the entire accumulator - using AccumulatorIterations = MatrixShape; - - /// Number of Accesses in a warp - using AccessIterations = MatrixShape; - - /// Number of K iterations - static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; - static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; - static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn; - static int const kResidualIndex = kResidualColumn / Shape::kColumn; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - /// This is the fragment size produced by one access of the iterator. - using Fragment = Array; - - /// Accumulator Fragment object - using AccumulatorFragment = Array; - - /// Scale Bias Element Type - using ElementScaleBias = typename OutputOp::ElementCompute; - - /// Scale Bias Fragment object - using ScaleBiasFragment = Array; - - -private: - - /// Internal access type - using AccessType = Array; - using FragmentAccessType = Array; - using ScaleBiasAccessType = Array; - -private: - // - // Data members - // - - /// Accumulator tile - AccessType const *accumulators_; - - /// Internal index - int index_; - - /// Used to access residual tile first - bool is_residual_tile_; - -public: - /// Constructs an iterator - CUTLASS_HOST_DEVICE - MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) - : accumulators_(reinterpret_cast(&accum)), - index_(0), is_residual_tile_(true) {} - - /// Add offset - CUTLASS_HOST_DEVICE - void add_offset(int index_offset) { - index_ += index_offset; - if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { - index_ = index_ - kKBlockColumnIterations + kResidualIndex; - is_residual_tile_ = false; - } - } - - /// Increments - CUTLASS_HOST_DEVICE - MmaTensorOpFragmentIterator &operator++() { - add_offset(1); - return *this; - } - - /// Decrements - CUTLASS_HOST_DEVICE - MmaTensorOpFragmentIterator &operator--() { - add_offset(-1); - return *this; - } - - CUTLASS_HOST_DEVICE - void set_index(int idx) { - index_ = idx; - } - - /// Loads a fragment from the referenced part of the accumulator tile - CUTLASS_HOST_DEVICE - void load(Fragment &frag, OutputOp output_op) const { - - if (output_op.is_source_needed()) //beta must be zero - assert(0); - - FragmentAccessType *frag_ptr = reinterpret_cast(&frag); - - int index = index_ * AccessIterations::kCount; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < AccessIterations::kCount; i++) { - - int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + - (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * - AccumulatorIterations::kColumn * kIterationsPerInstruction + - (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * - (kIterationsPerInstruction * kIterationsPerAccess) + - (index % kIterationsPerInstruction); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kIterationsPerAccess; j++) { - - frag_ptr[i*kIterationsPerAccess + j].clear(); - if(!(is_residual_tile_ && index_ >= kResidualIndex)) - frag_ptr[i*kIterationsPerAccess + j] = output_op(accumulators_[accumulator_access_offset + j * kAccessStride]); - } - index++; - } - } - - /// Loads a fragment from the referenced part of the accumulator tile - /// Then apply per-channel scale and bias - CUTLASS_HOST_DEVICE - void load(Fragment &frag, ScaleBiasFragment &scale, - ScaleBiasFragment & bias, OutputOp output_op) const { - - if (output_op.is_source_needed()) //beta must be zero - assert(0); - - FragmentAccessType *frag_ptr = reinterpret_cast(&frag); - ScaleBiasAccessType * scale_ptr = reinterpret_cast(&scale); - ScaleBiasAccessType * bias_ptr = reinterpret_cast(&bias); - - int index = index_ * AccessIterations::kCount; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < AccessIterations::kCount; i++) { - - int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + - (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * - AccumulatorIterations::kColumn * kIterationsPerInstruction + - (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * - (kIterationsPerInstruction * kIterationsPerAccess) + - (index % kIterationsPerInstruction); - - int scale_bias_offset = (index - % (kIterationsPerInstruction * AccessIterations::kColumn)) - * kIterationsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kIterationsPerAccess; j++) { - - - frag_ptr[i*kIterationsPerAccess + j].clear(); - if(!(is_residual_tile_ && index_ >= kResidualIndex)) - frag_ptr[i*kIterationsPerAccess + j] = output_op( - accumulators_[accumulator_access_offset + j * kAccessStride], - scale_ptr[scale_bias_offset + j], bias_ptr[scale_bias_offset + j]); - } - index++; - } - } - -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h deleted file mode 100644 index febd0e48be683db49b588d2e5c1d56de39d2ad13..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h +++ /dev/null @@ -1,65 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Policy describing implementation details of warp-level GEMM targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/gemm/gemm.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Policy -template < - typename Operator_, ///< hardware instruction(s) performing TensorOp (concept: arch::Mma) - typename OpDelta_ ///< distance between operations (concept: MatrixShape) -> -struct MmaTensorOpPolicy { - - using Operator = Operator_; ///< hardware instruction(s) performing TensorOp (concept: arch::Mma) - using OpDelta = OpDelta_; ///< distance between operations (concept: MatrixShape) - using MmaShape = typename Operator::Shape; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h deleted file mode 100644 index e7a4d87f99ae8ff97e8ca615a74c923e2f745fc9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h +++ /dev/null @@ -1,280 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. - - This is a work in progress. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/mma.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Used for partial specialization - typename Enable = bool -> -class MmaVoltaTensorOp { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Architecture tag - using ArchTag = arch::Sm70; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Underlying instruction shape - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// interleaved 32x32 tiles - using InterleavedTileShape = GemmShape<32, 32, 4>; - - static_assert(!(Shape::kM % InterleavedTileShape::kM) && - !(Shape::kN % InterleavedTileShape::kN), - "Shape must be a multiple of InterleavedTileShape."); -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaVoltaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape< - ArchMmaOperator::Shape::kM, - ArchMmaOperator::Shape::kK - >, - Policy::OpDelta::kRow, - kThreadCount - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Iterates over the B operand in memory - using IteratorB = MmaVoltaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape< - ArchMmaOperator::Shape::kK, - ArchMmaOperator::Shape::kN - >, - Policy::OpDelta::kRow, - kThreadCount - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Iterates over the C operand in memory - using IteratorC = MmaVoltaTensorOpAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta - >; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - -private: - - static_assert( - !(Shape::kM % ArchMmaOperator::Shape::kM) && - !(Shape::kN % ArchMmaOperator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - InterleavedTileShape::kM / ArchMmaOperator::Shape::kM, - InterleavedTileShape::kN / ArchMmaOperator::Shape::kN - >; - using TileIterations = MatrixShape< - Shape::kM / InterleavedTileShape::kM, - Shape::kN / InterleavedTileShape::kN - >; - - // Whether matrix B is reordered - bool reorder_B_; - -public: - - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaVoltaTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - D = C; - - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - - CUTLASS_PRAGMA_UNROLL - for (int outer_col = 0; outer_col < TileIterations::kColumn; ++outer_col) { - CUTLASS_PRAGMA_UNROLL - for (int inner_col = 0; inner_col < MmaIterations::kColumn; ++inner_col) { - CUTLASS_PRAGMA_UNROLL - for (int outer_row = 0; outer_row < TileIterations::kRow; ++outer_row) { - CUTLASS_PRAGMA_UNROLL - - for (int inner_row = 0; inner_row < MmaIterations::kRow; ++inner_row) { - - int op_col = inner_col + MmaIterations::kColumn * outer_col; - - // Column-major serpentine sequence to maximize reuse of A operand. - int inner_row_serp = inner_row; - int outer_row_serp = outer_row; - if (op_col & 1) { - inner_row_serp = MmaIterations::kRow - inner_row - 1; - outer_row_serp = TileIterations::kRow - outer_row - 1; - } - int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp; - int op_idx = inner_row_serp + MmaIterations::kRow * - (inner_col + MmaIterations::kColumn * - (outer_row_serp + TileIterations::kRow * outer_col)); - mma( - ptr_D[op_idx], - ptr_A[op_row], - ptr_B[op_col], - ptr_D[op_idx]); - - } - } - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h deleted file mode 100644 index f37c5c1434c0f1887ce70ae8a11eea25b6c293d6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h +++ /dev/null @@ -1,362 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - - -/// Tile access iterator -/// Each iteration access in the tile is -/// used as multiplicand for one -/// warp-level matrix multiplication -template < - /// Size of the tile (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand_, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of one matrix production operation (concept: MatrixShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads = 32, - /// Enable Residual Support - bool EnableResidual = false, - /// Number of partitions along K dimension - int PartitionsK_ = 1 -> -class MmaTensorOpMultiplicandTileAccessIterator { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - /// Basic check - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Number of elements accessed per Shared Memory load - static int const kElementsPerAccess = - (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); - - using InstructionCount = MatrixShape< - Shape::kRow / InstructionShape::kRow, - Shape::kColumn / InstructionShape::kColumn - >; - - static int const kIterations = (kOperand == Operand::kA) ? - InstructionCount::kColumn : InstructionCount::kRow; - - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array< - Element, - (kOperand == Operand::kA) ? - (Shape::kRow * InstructionShape::kColumn / kThreads) : - (Shape::kColumn * InstructionShape::kRow / kThreads) - >; - - /// Memory access type - using AccessType = AlignedArray; - -private: - - /// Underlying tensor reference - TensorRef ref_; - - /// Extent of tensor - MatrixCoord extent_; - - /// Origin - MatrixCoord origin_; - - /// Used to load residual tile - bool is_residual_; - - /// residual offset of each thread - TensorCoord residual_offset_; - - /// Iterations in a tile - int iterations_; - -public: - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileAccessIterator( - TensorRef const &ref, - TensorCoord extent, - int lane_id - ): ref_(ref), extent_(extent), is_residual_(false), iterations_(0) { - - if (kOperand == Operand::kA) { - origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); - } - else { - origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); - } - - ref_.add_coord_offset(origin_); - - if(EnableResidual) { - // compute residual offset - if (kOperand == Operand::kA) { - typename TensorCoord::Index residual_size = - extent_.column() % Shape::kColumn; - if(residual_size) { - is_residual_ = true; - residual_offset_ = make_Coord(0, residual_size); - } - } - else { - typename TensorCoord::Index residual_size = - extent_.row() % Shape::kRow; - if(residual_size) { - is_residual_ = true; - residual_offset_ = make_Coord(residual_size, 0); - } - } - } - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileAccessIterator( - TensorRef const &ref, - int lane_id - ): MmaTensorOpMultiplicandTileAccessIterator(ref, - {Shape::kRow, Shape::kColumn}, lane_id) { - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileAccessIterator &add_tile_offset(TensorCoord const &tile_offset) { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - origin_ += coord_offset; - - ref_.add_coord_offset(coord_offset); - - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - void advance() { - - if(EnableResidual && is_residual_) { - is_residual_ = false; - - origin_ += residual_offset_; - ref_.add_coord_offset(residual_offset_); - - } - - else { - if (kOperand == Operand::kA) { - add_tile_offset({0, 1}); - } - else { - add_tile_offset({1, 0}); - } - } - - iterations_ = 0; - } - - /// increase iterations in a tile - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileAccessIterator & operator++() { - - iterations_++; - - if(iterations_ >= kIterations) - advance(); - - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - int const kWarpShapeDivisibleInner = - (kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow); - - // Take advantage of Tensor Op's 8 x 4T access pattern - int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; - - AccessType *access_ptr = reinterpret_cast(&frag); - - if (kOperand == Operand::kA) { - int const kTilesPerInstruction = InstructionShape::kRow / 8; - - CUTLASS_PRAGMA_UNROLL - for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { - - CUTLASS_PRAGMA_UNROLL - for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { - - CUTLASS_PRAGMA_UNROLL - for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { - int access_idx = - access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); - - MatrixCoord offset( - access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, - inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kColumn); - - MatrixCoord access_coord = origin_ + offset; - -// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) { - - access_ptr[access_idx] = *reinterpret_cast( - ref_.data() + ref_.offset(offset)); -// } -// else { -// AccessType zero; -// zero.clear(); -// access_ptr[access_idx] = zero; -// } - } - } - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { - - CUTLASS_PRAGMA_UNROLL - for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { - int access_idx = inner_idx + kAccessesInner * inst_n_idx; - - MatrixCoord offset( - inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kRow, - inst_n_idx * 8); - - MatrixCoord access_coord = origin_ + offset; - -// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) { - - access_ptr[access_idx] = *reinterpret_cast( - ref_.data() + ref_.offset(offset)); -// } -// else { -// AccessType zero; -// zero.clear(); -// access_ptr[access_idx] = zero; -// } - } - } - } - } - -}; - - - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h deleted file mode 100644 index dd15097d3ebd0e2e4c663c9ee57e0e6520eb6b6b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ /dev/null @@ -1,4803 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads, - /// Number of partitions along K dimension - int PartitionsK_ = 1> -class MmaTensorOpMultiplicandTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCongruous::value, - 64>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous< - sizeof_bits::value, 64>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Determine number of elements along outer dimension per individual LDSM op - static int const kLdsmOpOuter = Layout::kElementsPerAccess; - static int const kLdsmOpInner = 8; - - static_assert(!(Shape::kContiguous % kLdsmOpOuter), - "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); - - static_assert(!(Shape::kStrided % kLdsmOpInner), - "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); - - /// Shape of one individual LDSM instruction - static int const LdsmShapeStrided = - InstructionShape::kStrided / kLdsmOpInner; - static int const LdsmShapeContiguous = 4 / LdsmShapeStrided; - using LdsmShape = - layout::PitchLinearShape; - - /// Number and arrangement of LDSM instructions - using LdsmIterations = layout::PitchLinearShape< - Shape::kContiguous / Layout::kElementsPerAccess / LdsmShapeContiguous, - 1>; - - /// Number of groups for each tile - static int const kGroupsPerTile = - Shape::kStrided / InstructionShape::kStrided; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Number of internal pointers needed to reference shared memory - static int const kPointerCount = - Layout::TileShape::kContiguous / Policy::LdsmShape::kContiguous; - - /// Pointer type used for accesses - using AccessType = Array; - - /// Internal counter used to jump to next K partition - int k_group_idx_; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_[kPointerCount]; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / Layout::kElementsPerAccess), - byte_offset_(0), - k_group_idx_(0) { - - int quad_pair = (lane_id >> 3); - int quad_quad = (lane_id >> 4); - int lane_in_quad = (lane_id & 3); - int lane_in_quad_pair = (lane_id & 7); - int lane_in_quad_quad = (lane_id & 15); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount; ++i) { - int partition_contiguous_idx = -1; - int access_contiguous_idx = -1; - int access_strided_idx = -1; - - if (Policy::LdsmShape::kContiguous == 4) { - // Matrix multiply 1688 A/B - // Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block). - // Four blocks are next to each other in the contiguous dimension. - partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i); - access_contiguous_idx = (quad_pair ^ lane_in_quad); - access_strided_idx = lane_in_quad_pair; - } else if (Policy::LdsmShape::kContiguous == 2 && - kOperand == Operand::kA) { - // Matrix multiply 16816 A - // Q0 Q1 - // Q2 Q3 - partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); - access_contiguous_idx = - (((quad_pair & 1) + ((i & 1) << 1)) ^ lane_in_quad); - access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); - } else if (Policy::LdsmShape::kContiguous == 2 && - kOperand == Operand::kB) { - // Matrix multiply 16816 B - // Q0 Q2 - // Q1 Q3 - partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); - access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad); - access_strided_idx = lane_in_quad_quad; - } else if (Policy::LdsmShape::kContiguous == 1) { - // Matrix multiply 16832.SP B - // Q0 - // Q1 - // Q2 - // Q3 - partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2)); - access_contiguous_idx = ((i & 3) ^ lane_in_quad); - access_strided_idx = lane_id; - } - - int access_contiguous = - partition_contiguous_idx * Layout::PartitionShape::kContiguous + - access_contiguous_idx; - - int access_strided = access_strided_idx; - - pointer_[i] = reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - } - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int contiguous_offset = tile_offset.contiguous(); - if (Shape::kContiguous == - Layout::PartitionShape::kContiguous * Layout::kElementsPerAccess) { - if (tile_offset.contiguous() % 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount / 2; ++i) { - AccessType const *tmp_pointer = pointer_[i]; - pointer_[i] = pointer_[i + kPointerCount / 2]; - pointer_[i + kPointerCount / 2] = tmp_pointer; - } - } - contiguous_offset = (tile_offset.contiguous() >> 1) << 1; - } - - int offset = (tile_offset.strided() * InstructionShape::kStrided) * - stride_ * Layout::kElementsPerAccess + - contiguous_offset * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - add_tile_offset({0, 1}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == Policy::kGroupsPerTile) { - k_group_idx_ = 0; - add_tile_offset( - {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); - } - } - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - Array *fetch_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsmIterations::kContiguous; - - AccessType const *source_ptr = - pointer_[c % kPointerCount] + - Layout::TileShape::kContiguous * (c / kPointerCount) + - Policy::kLdsmOpInner * Policy::LdsmShape::kStrided * s * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - cutlass::arch::ldsm( - fetch_ptr[access_idx], - source_byte_ptr - ); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no op - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread MMA.TF32 NT TensorOps. It -/// uses LDS.32 to load from shared memory and therefore must be initialized -/// with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCongruous<32, 32>, InstructionShape_, - OpDelta_, 32, PartitionsK_> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand == Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for " - "A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous<32, 32>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Determine number of elements along outer dimension per individual 32bit - // shared memory load op. Every one warp of 32bit shared memory load loads - // 8x4 elements - static int const kLdsOpInner = Layout::TileShape::kStrided; - static int const kLdsOpOuter = kThreads / kLdsOpInner; - - static_assert(!(Shape::kContiguous % kLdsOpOuter), - "Shape of warp-level mma must be divisible by 32bit " - "fundamental tile size."); - - static_assert(!(Shape::kStrided % kLdsOpInner), - "Shape of warp-level mma must be divisible by 32bit " - "fundamental tile size."); - - /// Number of 32 bit shared memory load instructions needed by one MMA instruction - /// 1688 A 2x2 - /// 1688 B 1x2 - /// 16816 B 1x4 - static int const LdsShapeContiguous = - InstructionShape::kContiguous / kLdsOpOuter; - static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner; - using LdsShape = - layout::PitchLinearShape; - - /// Number and arrangement of LDS instructions - using LdsIterations = layout::PitchLinearShape< - Shape::kContiguous / LdsShapeContiguous / kLdsOpOuter, 1>; - - /// Number of groups for each tile - static int const kGroupsPerTile = - Shape::kStrided / InstructionShape::kStrided; - }; - - private: - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Number of internal pointers needed to reference shared memory - static int const kPointerCount = Layout::TileShape::kContiguous * - Layout::kElementsPerAccess / - Policy::kLdsOpOuter; - - /// Vectorized access is not used - static int const kElementsPerAccess = 1; - - /// Pointer type used for accesses - using AccessType = Element; - - /// Internal counter used to jump to next K partition - int k_group_idx_; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - - private: - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_[kPointerCount]; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() : stride_(0), byte_offset_(0) {} - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : stride_(ref.stride(0)), byte_offset_(0), k_group_idx_(0) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount; ++i) { - int access_strided = lane_id % Policy::kLdsOpInner; - int access_contiguous = (lane_id / Policy::kLdsOpInner) + - (access_strided ^ i) * Policy::kLdsOpOuter; - - pointer_[i] = reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - } - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - int contiguous_offset = tile_offset.contiguous(); - if (Shape::kContiguous == - Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) { - if (tile_offset.contiguous() % 2) { - // Matrix multiply 1688 pointer_[0] <=> pointer_[4] pointer_[1] <=> pointer_[5] - // pointer_[2] <=> pointer_[6] pointer_[3] <=> pointer_[7] - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount / 2; ++i) { - AccessType const *tmp_pointer = pointer_[i]; - pointer_[i] = pointer_[i + kPointerCount / 2]; - pointer_[i + kPointerCount / 2] = tmp_pointer; - } - } - contiguous_offset = (tile_offset.contiguous() >> 1) << 1; - } - - int offset = (tile_offset.strided() * InstructionShape::kStrided) * stride_ + - contiguous_offset * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator++() { - add_tile_offset({0, 1}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == Policy::kGroupsPerTile) { - k_group_idx_ = 0; - add_tile_offset( - {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); - } - } - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &operator--() { - byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * - kElementsPerAccess; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - Element *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int ss = 0; ss < Policy::LdsShape::kStrided; ++ss) { - CUTLASS_PRAGMA_UNROLL - for (int cc = 0; cc < Policy::LdsShape::kContiguous; ++cc) { - int access_idx = - cc + (ss + (c + s * Policy::LdsIterations::kContiguous) * - Policy::LdsShape::kStrided) * - Policy::LdsShape::kContiguous; - int access_idx_contiguous = cc + c * Policy::LdsShape::kContiguous; - int access_idx_strided = - (ss + s * Policy::LdsShape::kStrided) * Policy::kLdsOpInner; - - AccessType const *source_ptr = - pointer_[access_idx_contiguous % kPointerCount] + - Layout::TileShape::kContiguous * Layout::kElementsPerAccess * - (access_idx_contiguous / kPointerCount) + - access_idx_strided * stride_; - - char const *source_byte_ptr = - reinterpret_cast(source_ptr) + byte_offset + - byte_offset_; - - fetch_ptr[access_idx] = - *reinterpret_cast(source_byte_ptr); - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / - Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no op - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps with 64B warp tile -/// the contiguous dimension. This assumes Threadblock contiguous dimension has -/// the same size as the warp tile. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// This specialization can be merged into the general one. Most code is the same. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCongruous<16, 32>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Element number when the layout crosses - static int const kCrosswise = 32; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Determine number of elements along outer dimension per individual LDSM op - static int const kLdsmOpOuter = Layout::kElementsPerAccess; - static int const kLdsmOpInner = 8; - - static_assert(!(Shape::kContiguous % kLdsmOpOuter), - "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); - - static_assert(!(Shape::kStrided % kLdsmOpInner), - "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); - - /// Shape of one individual LDSM instruction - static int const LdsmShapeStrided = - InstructionShape::kStrided / kLdsmOpInner; - static int const LdsmShapeContiguous = 4 / LdsmShapeStrided; - using LdsmShape = - layout::PitchLinearShape; - - /// Number and arrangement of LDSM instructions - using LdsmIterations = layout::PitchLinearShape< - Shape::kContiguous / Layout::kElementsPerAccess / LdsmShapeContiguous, - 1>; - - /// Number of groups for each tile - static int const kGroupsPerTile = - Shape::kStrided / InstructionShape::kStrided; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Number of internal pointers needed to reference shared memory - static int const kPointerCount = - Layout::TileShape::kContiguous / Policy::LdsmShape::kContiguous / Layout::kFactor; - - /// Pointer type used for accesses - using AccessType = Array; - - /// Internal counter used to jump to next K partition - int k_group_idx_; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_[kPointerCount]; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), - byte_offset_(0), - k_group_idx_(0) { - - int quad_pair = (lane_id >> 3); - int quad_quad = (lane_id >> 4); - //int lane_in_quad = (lane_id & 3); - int lane_in_quad_pair = (lane_id & 7); - int lane_in_quad_quad = (lane_id & 15); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount; ++i) { - int partition_contiguous_idx = -1; - int access_contiguous_idx = -1; - int access_strided_idx = -1; - - if (Policy::LdsmShape::kContiguous == 4) { - // Matrix multiply 1688 A/B - // Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block). - // Four blocks are next to each other in the contiguous dimension. - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = quad_pair ^ (lane_in_quad_pair / Layout::kFactor); - access_strided_idx = lane_in_quad_pair / Layout::kFactor; - } else if (Policy::LdsmShape::kContiguous == 2 && - kOperand == Operand::kA) { - // Matrix multiply 16816 A - // Q0 Q1 - // Q2 Q3 - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = - (((quad_pair & 1) + i * 2) ^ (lane_in_quad_pair / Layout::kFactor)); - access_strided_idx = (lane_in_quad_pair + (lane_id >> 4 << 3)) / 2; - } else if (Policy::LdsmShape::kContiguous == 2 && - kOperand == Operand::kB) { - // Matrix multiply 16816 B - // Q0 Q2 - // Q1 Q3 - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = (quad_quad + i * 2) ^ (lane_in_quad_pair / Layout::kFactor); - access_strided_idx = (lane_in_quad_quad / Layout::kFactor); - } else if (Policy::LdsmShape::kContiguous == 1) { - // Matrix multiply 16832.SP B - // Q0 - // Q1 - // Q2 - // Q3 - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor) ^ i; - access_strided_idx = lane_id / Layout::kFactor; - } - - int access_contiguous = - partition_contiguous_idx * Layout::PartitionShape::kContiguous + - access_contiguous_idx; - - int access_strided = access_strided_idx; - - pointer_[i] = reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - } - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int contiguous_offset = tile_offset.contiguous(); - if (Shape::kContiguous == - Layout::PartitionShape::kContiguous * Layout::kElementsPerAccess) { - if (tile_offset.contiguous() % 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount / 2; ++i) { - AccessType const *tmp_pointer = pointer_[i]; - pointer_[i] = pointer_[i + kPointerCount / 2]; - pointer_[i + kPointerCount / 2] = tmp_pointer; - } - } - contiguous_offset = (tile_offset.contiguous() >> 1) << 1; - } - - int offset = (tile_offset.strided() * InstructionShape::kStrided) * - stride_ * Layout::kElementsPerAccess / Layout::kFactor + - contiguous_offset * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - add_tile_offset({0, 1}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == Policy::kGroupsPerTile) { - k_group_idx_ = 0; - add_tile_offset( - {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); - } - } - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - Array *fetch_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsmIterations::kContiguous; - - AccessType const *source_ptr = - pointer_[c % kPointerCount] + - Layout::TileShape::kContiguous * (c / kPointerCount) + - Policy::kLdsmOpInner * Policy::LdsmShape::kStrided * s * stride_ / Layout::kFactor; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - cutlass::arch::ldsm( - fetch_ptr[access_idx], - source_byte_ptr - ); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_ / Layout::kFactor; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no op - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps with 32B warp tile -/// the contiguous dimension. This assumes Threadblock contiguous dimension has -/// the same size as the warp tile. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// This specialization can be merged into the general one. Most code is the same. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCongruous<16, 16>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Element number when the layout crosses - static int const kCrosswise = 16; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Determine number of elements along outer dimension per individual LDSM op - static int const kLdsmOpOuter = Layout::kElementsPerAccess; - static int const kLdsmOpInner = 8; - - static_assert(!(Shape::kContiguous % kLdsmOpOuter), - "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); - - static_assert(!(Shape::kStrided % kLdsmOpInner), - "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); - - /// Shape of one individual LDSM instruction - static int const LdsmShapeStrided = - InstructionShape::kStrided / kLdsmOpInner; - static int const LdsmShapeContiguous = 4 / LdsmShapeStrided; - using LdsmShape = - layout::PitchLinearShape; - - /// Number and arrangement of LDSM instructions - using LdsmIterations = layout::PitchLinearShape< - Shape::kContiguous / Layout::kElementsPerAccess / LdsmShapeContiguous, - 1>; - - /// Number of groups for each tile - static int const kGroupsPerTile = - Shape::kStrided / InstructionShape::kStrided; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Number of internal pointers needed to reference shared memory - static int const kPointerCount = - Layout::TileShape::kContiguous / Policy::LdsmShape::kContiguous / Layout::kFactor; - - /// Pointer type used for accesses - using AccessType = Array; - - /// Internal counter used to jump to next K partition - int k_group_idx_; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_[kPointerCount]; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), - byte_offset_(0), - k_group_idx_(0) { - - //int quad_pair = (lane_id >> 3); - int quad_quad = (lane_id >> 4); - int lane_in_pair = (lane_id & 1); - int lane_in_quad = (lane_id & 3); - int lane_in_quad_pair = (lane_id & 7); - int lane_in_quad_quad = (lane_id & 15); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount; ++i) { - int partition_contiguous_idx = -1; - int access_contiguous_idx = -1; - int access_strided_idx = -1; - - if (Policy::LdsmShape::kContiguous == 2 && - kOperand == Operand::kA) { - // Matrix multiply 16816 A - // Q0 Q1 - // Q2 Q3 - partition_contiguous_idx = lane_in_quad / 2; - access_strided_idx = lane_in_quad_pair / Layout::kFactor + quad_quad * 2; - access_contiguous_idx = - ((lane_in_pair * 2 + ((lane_id & 8) >> 3)) ^ - access_strided_idx); - } else if (Policy::LdsmShape::kContiguous == 2 && - kOperand == Operand::kB) { - // Matrix multiply 16816 B - // Q0 Q2 - // Q1 Q3 - partition_contiguous_idx = lane_in_quad / 2; - access_strided_idx = lane_in_quad_quad / Layout::kFactor; - access_contiguous_idx = - ((lane_in_pair * 2 + quad_quad) ^ - access_strided_idx); - } else if (Policy::LdsmShape::kContiguous == 1) { - // Matrix multiply 16832.SP B - // Q0 - // Q1 - // Q2 - // Q3 - int factor_in_partition = - (Layout::PartitionShape::kContiguous * Layout::kFactor / - Layout::TileShape::kContiguous); - - partition_contiguous_idx = lane_in_quad / factor_in_partition; - access_contiguous_idx = ((lane_in_pair * factor_in_partition) ^ - (lane_in_quad_quad / Layout::kFactor) ^ i); - access_strided_idx = lane_id / Layout::kFactor; - } - - int access_contiguous = - partition_contiguous_idx * Layout::PartitionShape::kContiguous + - access_contiguous_idx; - - int access_strided = access_strided_idx; - - pointer_[i] = reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - } - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int contiguous_offset = tile_offset.contiguous(); - if (Shape::kContiguous == - Layout::PartitionShape::kContiguous * Layout::kElementsPerAccess) { - if (tile_offset.contiguous() % 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount / 2; ++i) { - AccessType const *tmp_pointer = pointer_[i]; - pointer_[i] = pointer_[i + kPointerCount / 2]; - pointer_[i + kPointerCount / 2] = tmp_pointer; - } - } - contiguous_offset = (tile_offset.contiguous() >> 1) << 1; - } - - int offset = (tile_offset.strided() * InstructionShape::kStrided) * - stride_ * Layout::kElementsPerAccess / Layout::kFactor + - contiguous_offset * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - add_tile_offset({0, 1}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == Policy::kGroupsPerTile) { - k_group_idx_ = 0; - add_tile_offset( - {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); - } - } - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - Array *fetch_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsmIterations::kContiguous; - - AccessType const *source_ptr = - pointer_[c % kPointerCount] + - Layout::TileShape::kContiguous * (c / kPointerCount) + - Policy::kLdsmOpInner * Policy::LdsmShape::kStrided * s * stride_ / Layout::kFactor; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - cutlass::arch::ldsm( - fetch_ptr[access_idx], - source_byte_ptr - ); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_ / Layout::kFactor; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no op - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Element number when the layout crosses (in units of elements) - int Crosswise, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA, - "MmaTensorOpMultiplicandIterator for ColumnMajor Congruous may " - "only be instantiated for A operand to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// MBlock or NBlock size - static int const kCrosswise = Crosswise; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCongruous::value, - kCrosswise>, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Element number when the layout crosses (in units of elements) - int Crosswise, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kB, - "MmaTensorOpMultiplicandIterator for RowMajor Congruous may " - "only be instantiated for B operand to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Element number when the layout crosses - static int const kCrosswise = Crosswise; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCongruous::value, - kCrosswise>, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.strided(), tile_offset.contiguous()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Element number when the layout crosses (in units of elements) - int Crosswise, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCrosswise::value, - Crosswise>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand == Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for " - "A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Element number when the layout crosses - static int const kCrosswise = Crosswise; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCrosswise< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Determine number of elements along outer dimension per individual LDSM op - static int const kLdsmOpOuter = Layout::kElementsPerAccess; - static int const kLdsmOpInner = 8; - - static_assert(!(Shape::kContiguous % kLdsmOpOuter), - "Shape of warp-level mma must be divisible by LDSM's " - "fundamental tile size."); - - static_assert(!(Shape::kStrided % kLdsmOpInner), - "Shape of warp-level mma must be divisible by LDSM's " - "fundamental tile size."); - - /// Shape of one individual LDSM instruction - static int const LdsmShapeContiguous = - InstructionShape::kContiguous / kLdsmOpOuter; - static int const LdsmShapeStrided = - ((4 / LdsmShapeContiguous * kLdsmOpInner) > Shape::kStrided) - ? (Shape::kStrided / kLdsmOpInner) - : (4 / LdsmShapeContiguous); - using LdsmShape = - layout::PitchLinearShape; - - /// Number and arrangement of LDSM instructions - using LdsmIterations = - layout::PitchLinearShape<1, Shape::kStrided / kLdsmOpInner / - LdsmShape::kStrided>; - - /// - static int const kGroupsPerTile = Layout::TileShape::kContiguous / - Layout::kFactor / LdsmShape::kContiguous; - }; - - private: - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = Array; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - - private: - - /// Total number of sections. The memory is divided into stages. One stage - /// can store one tile. Stage is divided into sections. Interleaved layout - /// can have multiple sections in a stage. The rest layout only has one section - /// in a stage. - int sections_; - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - /// Internal counter used to determine when to increment byte offset and when - /// to XOR it - int k_group_idx_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() - : pointer_(nullptr), - sections_(0), - stride_(0), - byte_offset_(0), - k_group_idx_(0) {} - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : pointer_(reinterpret_cast(ref.data())), - sections_(ref.stride(0) / kCrosswise), - // stride_ = kCrosswise x sections_ x kFactor - stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), - byte_offset_(0), - k_group_idx_(0) { - // Warp level iterator at most use double buffer to hide latency. If there - // are more than 2 sections, every stage should have more than 1 section. - - // Turing silicon requires all 32 threads in a warp provide valid addresses - // even for LDSM.1 and LDSM.2 -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 750)) - lane_id = lane_id % (Policy::LdsmShape::kCount * Policy::kLdsmOpInner); -#endif - - int quad_quad = (lane_id >> 4); - int quad_pair = (lane_id >> 3); - int lane_in_pair = (lane_id & 1); - int lane_in_quad = (lane_id & 3); - int lane_in_quad_pair = (lane_id & 7); - int lane_in_quad_quad = (lane_id & 15); - - int partition_contiguous_idx = -1; - int access_contiguous_idx = -1; - int access_strided_idx = -1; - - if (Layout::kFactor == 8) { - int factor_in_partition = - (Layout::PartitionShape::kContiguous * Layout::kFactor / - Layout::TileShape::kContiguous); - - if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { - partition_contiguous_idx = lane_in_quad_pair / factor_in_partition; - access_contiguous_idx = ((lane_in_quad) ^ (lane_id / Layout::kFactor)); - access_strided_idx = lane_id / Layout::kFactor; - } - } else if (Layout::kFactor == 4) { - // Super Integer matrix multiply Interleaved-32 - - int factor_in_partition = - (Layout::PartitionShape::kContiguous * Layout::kFactor / - Layout::TileShape::kContiguous); - - if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { - // Integer matrix multiply 8816 A/B - partition_contiguous_idx = lane_in_quad / factor_in_partition; - access_contiguous_idx = ((lane_in_pair * factor_in_partition) ^ - (lane_in_quad_quad / Layout::kFactor)); - access_strided_idx = lane_id / Layout::kFactor; - } - else if (Policy::LdsmShape::kStrided == - (Policy::LdsmShape::kCount / 2) && - kOperand == Operand::kA) { - // Integer matrix multiply 16832 A - partition_contiguous_idx = lane_in_quad / factor_in_partition; - access_strided_idx = lane_in_quad_quad / Layout::kFactor; - access_contiguous_idx = - ((lane_in_pair * factor_in_partition + quad_quad) ^ - access_strided_idx); - } - else if (Policy::LdsmShape::kStrided == - (Policy::LdsmShape::kCount / 2) && - kOperand == Operand::kB) { - // Integer matrix multiply 16832 B - partition_contiguous_idx = lane_in_quad / factor_in_partition; - access_strided_idx = lane_in_quad_pair / Layout::kFactor + quad_quad * 2; - access_contiguous_idx = - ((lane_in_pair * factor_in_partition + ((lane_id & 8) >> 3)) ^ - access_strided_idx); - } - } else if (Layout::kFactor == 2) { - // Super Matrix multiply kBlock = 32 - if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { - // Matrix multiply 1688 A/B - // (Q stands for 1 8x128bit block). - // Q0 - // Q1 - // Q2 - // Q3 - // Four blocks are next to each other in the strided dimension. - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor); - access_strided_idx = lane_id / Layout::kFactor; - } else if (Policy::LdsmShape::kStrided == - (Policy::LdsmShape::kCount / 2) && - kOperand == Operand::kA) { - // Matrix multiply 16816|1688.TF32 A - // Q0 Q2 - // Q1 Q3 - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = - (quad_quad ^ (lane_in_quad_pair / Layout::kFactor)); - access_strided_idx = (lane_in_quad_quad / Layout::kFactor); - } else if (Policy::LdsmShape::kStrided == - (Policy::LdsmShape::kCount / 2) && - kOperand == Operand::kB) { - // Matrix multiply 16816|1688.TF32 B - // Q0 Q1 - // Q2 Q3 - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = - ((quad_pair & 1) ^ (lane_in_quad_pair / Layout::kFactor)); - access_strided_idx = - (lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor; - } - else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { - // Matrix multiply 16832.SP B - // Q0 Q1 Q2 Q3 - partition_contiguous_idx = (lane_id % Layout::kFactor); - access_contiguous_idx = - (quad_pair ^ (lane_in_quad_pair / Layout::kFactor)); - access_strided_idx = lane_in_quad_pair / Layout::kFactor; - } - } else if (Layout::kFactor == 1) { - // Super Matrix multiply kBlock = 64 - if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { - // Q0 - // Q1 - // Q2 - // Q3 - partition_contiguous_idx = (lane_in_quad_pair >> 2); - access_contiguous_idx = lane_in_quad; - access_strided_idx = lane_id; - } - else if (Policy::LdsmShape::kStrided == - (Policy::LdsmShape::kCount / 2) && - kOperand == Operand::kA) { - // Matrix multiply 16816|1688.TF32 A - // Q0 Q2 - // Q1 Q3 - partition_contiguous_idx = (lane_in_quad_pair >> 2); - access_contiguous_idx = (quad_quad ^ lane_in_quad); - access_strided_idx = lane_in_quad_quad; - } else if (Policy::LdsmShape::kStrided == - (Policy::LdsmShape::kCount / 2) && - kOperand == Operand::kB) { - // Matrix multiply 16816|1688.TF32 B - // Q0 Q1 - // Q2 Q3 - partition_contiguous_idx = (lane_in_quad_pair >> 2); - access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad); - access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); - } - else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { - // Matrix multiply 16832.SP B - // Q0 Q1 Q2 Q3 - partition_contiguous_idx = (lane_in_quad_pair >> 2); - access_contiguous_idx = (quad_pair ^ lane_in_quad); - access_strided_idx = lane_in_quad_pair; - } - } - - int access_contiguous = - partition_contiguous_idx * Layout::PartitionShape::kContiguous + - access_contiguous_idx; - - int access_strided = access_strided_idx; - - byte_offset_ = (access_contiguous + access_strided * stride_) * - sizeof_bits::value * Layout::kElementsPerAccess / 8; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - byte_offset_ += offset * sizeof_bits::value / 8; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; - int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; - - byte_offset_ ^= k_groups_delta * sizeof_bits::value * - Layout::kElementsPerAccess * - Policy::LdsmShape::kContiguous / 8; - pointer_ += - tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + - whole_tiles * stride_ / sections_; - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( - TensorCoord const &tile_offset) { - - int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; - int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; - if (k_groups_delta < 0) { - whole_tiles -= 1; - k_groups_delta += Policy::kGroupsPerTile; - } - - if ((Policy::kGroupsPerTile / kPartitionsK) >= 2) { - byte_offset_ ^= (k_groups_delta & 1) * Policy::LdsmShape::kContiguous * - sizeof_bits::value * - Layout::kElementsPerAccess / 8; - } - if ((Policy::kGroupsPerTile / kPartitionsK) >= 4) { - byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 1)) & 2) * - Policy::LdsmShape::kContiguous * - sizeof_bits::value * - Layout::kElementsPerAccess / 8; - } - if ((Policy::kGroupsPerTile / kPartitionsK) == 8) { - byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 3)) & 4) * - Policy::LdsmShape::kContiguous * - sizeof_bits::value * - Layout::kElementsPerAccess / 8; - } - - k_group_idx_ += k_groups_delta; - whole_tiles += k_group_idx_ / (Policy::kGroupsPerTile / kPartitionsK); - k_group_idx_ = k_group_idx_ % (Policy::kGroupsPerTile / kPartitionsK); - - pointer_ += - tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + - whole_tiles * stride_ / sections_; - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator++() { - - // Integer matrix multiply 16832 Interleaved-32 - // NONE - // Integer matrix multiply 16816 Interleaved-32 || Integer matrix multiply 16816 kblock=32 - - // Integer matrix multiply 8816 Interleaved-32 - // ^1 ^1 - // Matrix multiply 1684.TF32 kblock=16 || Integer matrix multiply 16816 kblock=64 - // Matrix multiply 1688 kblock=32 || Integer matrix multiply 8816 kblock=64 - // ^1 ^3 ^1 ^3 - // Matrix multiply 1688 kblock=64 - // ^1 ^3 ^1 ^7 ^1 ^3 ^1 ^7 - - // Matrix multiply 16816 kblock=32 | 1688.TF32 kblock=16 || Integer matrix multiply 16832 kblock=64 - // ^2 ^2 - // Matrix multiply 16816 kblock=64 | 1688.TF32 kblock=32 || Integer matrix multiply 16832 kblock=128 - // ^2 ^6 ^2 ^6 - - if ((Policy::kGroupsPerTile / kPartitionsK) > 1) { - int mask = ((Policy::kGroupsPerTile / kPartitionsK) == 8) - ? 3 - : (((Policy::kGroupsPerTile / kPartitionsK) == 4) ? 1 : 0); - - if (((k_group_idx_ & mask) % 2) == 0) - byte_offset_ ^= 1 * Policy::LdsmShape::kContiguous * - sizeof_bits::value * - Layout::kElementsPerAccess / 8; - else if ((k_group_idx_ & mask) == 1) - byte_offset_ ^= 3 * Policy::LdsmShape::kContiguous * - sizeof_bits::value * - Layout::kElementsPerAccess / 8; - else if ((k_group_idx_ & mask) == 3) - byte_offset_ ^= 7 * Policy::LdsmShape::kContiguous * - sizeof_bits::value * - Layout::kElementsPerAccess / 8; - } - - k_group_idx_++; - - if (k_group_idx_ == (Policy::kGroupsPerTile / kPartitionsK)) { - k_group_idx_ = 0; - add_tile_offset({Policy::kGroupsPerTile, 0}); - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &operator--() { assert(0); } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - Array *fetch_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { - int access_idx = c + s * Policy::LdsmIterations::kContiguous; - - AccessType const *source_ptr = - pointer_ + Policy::LdsmShape::kContiguous * c + - Policy::kLdsmOpInner / Layout::kFactor * - Policy::LdsmShape::kStrided * s * stride_; - - char const *source_byte_ptr = - reinterpret_cast(source_ptr) + byte_offset + - byte_offset_; - - cutlass::arch::ldsm( - fetch_ptr[access_idx], source_byte_ptr); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = tile_offset.contiguous() * - InstructionShape::kContiguous / - Layout::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_; - - byte_offset += sizeof_bits::value * pointer_offset / 8; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - k_group_idx_ = k_group % (Policy::kGroupsPerTile / kPartitionsK); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Element number when the layout crosses (in units of elements) - int Crosswise, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kB, - "MmaTensorOpMultiplicandIterator for ColumnMajor Crosswise may " - "only be instantiated for B operand to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// KBlock size - static int const kCrosswise = Crosswise; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCrosswise::value, - kCrosswise>, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - - private: - /// Underlying tile iterator - Base iterator_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() {} - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : iterator_({ref.data(), ref.stride()}, lane_id) {} - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &operator++() { - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &operator--() { - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { iterator_.load(frag); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, {tile_offset.contiguous(), tile_offset.strided()}, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Element number when the layout crosses (in units of elements) - int Crosswise, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA, - "MmaTensorOpMultiplicandIterator for RowMajor Crosswise may " - "only be instantiated for A operand to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Element number when the layout crosses - static int const kCrosswise = Crosswise; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, kCrosswise>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCrosswise::value, - kCrosswise>, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - - private: - /// Underlying tile iterator - Base iterator_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() {} - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : iterator_({ref.data(), ref.stride()}, lane_id) {} - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &operator++() { - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &operator--() { - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { iterator_.load(frag); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpAccumulatorTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -/// accumulator layout. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpAccumulatorTileIterator< - Shape_, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static bool const kDivisible = - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, - (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN - >; - }; - -private: - - // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire - // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements - // of that row. The accumulators within one row are assumed to be consecutive. - static int const kElementsPerAccess = InstructionShape::kN / 4; - static int const kRowsPerTile = 8; - static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array< - Element, - Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - - frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - int idx = mma_accum_start + row * kElementsPerAccess + col; - - offset_ref.at({accum_m, accum_n}) = frag[idx]; - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -/// accumulators from memory and is agnostic to layout. -/// -/// This iterator is not tested. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpAccumulatorTileIterator< - Shape_, Element_, cutlass::layout::AffineRankN<2>, InstructionShape_, OpDelta_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static bool const kDivisible = - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, - (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN - >; - }; - -private: - - // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire - // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements - // of that row. The accumulators within one row are assumed to be consecutive. - static int const kElementsPerAccess = InstructionShape::kN / 4; - static int const kRowsPerTile = 8; - static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array< - Element, - Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - - frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - int idx = mma_accum_start + row * kElementsPerAccess + col; - - offset_ref.at({accum_m, accum_n}) = frag[idx]; - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -/// accumulator layout. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaTensorOpAccumulatorTileIterator { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static bool const kDivisible = - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, - (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN - >; - }; - -private: - - // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire - // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements - // of that row. The accumulators within one row are assumed to be consecutive. - static int const kElementsPerAccess = InstructionShape::kN / 4; - static int const kRowsPerTile = 8; - static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - int idx = mma_accum_start + row * kElementsPerAccess + col; - - frag[idx] = offset_ref.at({accum_m, accum_n}); - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = kAccumulatorRows * kElementsPerAccess * - (mma_n * Policy::MmaIterations::kRow + mma_m); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < kAccumulatorRows; ++row) { - CUTLASS_PRAGMA_UNROLL - for (int col = 0; col < kElementsPerAccess; ++col) { - int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + - row * kRowsPerTile; - int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; - int idx = mma_accum_start + row * kElementsPerAccess + col; - - offset_ref.at({accum_m, accum_n}) = frag[idx]; - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -/// accumulator layout. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element typ - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_, - /// Interleaved N - int InterleavedN> -class MmaTensorOpAccumulatorTileIterator< - Shape_, Element_, cutlass::layout::ColumnMajorInterleaved, - InstructionShape_, OpDelta_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorInterleaved; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape; - }; - -private: - - static int const kElementsPerAccess = 2; - -public: - - // - // Derived quantities - // - - using AccessType = Array; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - int accum_m = mma_m * InstructionShape::kM; - int accum_n = mma_n * InstructionShape::kN; - - int idx = mma_m + mma_n * Policy::MmaIterations::kRow; - - AccessType* access_ptr = reinterpret_cast(offset_ref.data() + - offset_ref.offset(TensorCoord(accum_m, accum_n))); - - frag_ptr[idx] = access_ptr[0]; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - int accum_m = mma_m * InstructionShape::kM; - int accum_n = mma_n * InstructionShape::kN; - - int idx = mma_m + mma_n * Policy::MmaIterations::kRow; - - AccessType* access_ptr = reinterpret_cast(offset_ref.data() + - offset_ref.offset(TensorCoord(accum_m, accum_n))); - - access_ptr[0] = frag_ptr[idx]; - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -/// accumulator layout. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element typ - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_, - /// Interleaved N - int InterleavedN> -class MmaTensorOpAccumulatorTileIterator< - Shape_, Element_, cutlass::layout::TensorNCxHWx, - InstructionShape_, OpDelta_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = int8_t; - - /// Layout of source tile - using Layout = cutlass::layout::TensorNCxHWx; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of elements in strided dimension that each STG writes - static int const kStridedPerSTG = 8; - - /// Factor to calculate reorder index to pack accumulator. - static int const kPackedFactor = Shape::kColumn / 32; - - /// Number of mma operations performed - using MmaIterations = MatrixShape; - }; - -private: - - static int const kElementsPerAccess = InterleavedN / 4; - -public: - - // - // Derived quantities - // - - struct alignas((kElementsPerAccess * sizeof_bits::value / 8)) AccessType { - Array storage; - }; - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Reference to output tensor - TensorRef ref_; - - /// Row offset index globally - LongIndex global_offset_row_; - - /// Column offset index globally - LongIndex global_offset_col_; - - /// Output tensor size - TensorCoord extent_; - - /// Alpha - float alpha_; - - /// Beta - float beta_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int const lane_id, - TensorCoord extent, - float alpha = 1.0f, - float beta = 0.0f - ): - ref_(ref), - extent_(extent), - alpha_(alpha), - beta_(beta) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - - global_offset_row_ = quad; - - global_offset_col_ = lane_in_quad * kElementsPerAccess; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator &add_tile_offset(MatrixCoord const &tile_offset) { - - global_offset_row_ += tile_offset.row() * Shape::kRow; - - global_offset_col_ += tile_offset.column() * Shape::kColumn; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kN; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kM; ++mma_m) { - int accum_m = mma_m * InstructionShape::kM; - int accum_n = mma_n * InstructionShape::kN; - - int idx = mma_m + mma_n * Policy::MmaIterations::kM; - - AccessType* access_ptr = reinterpret_cast(offset_ref.data() + - accum_m * offset_ref.stride(0) + accum_n); - - frag_ptr[idx] = access_ptr[0]; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - Array output_frag_f; - Array output_frag; - - LongIndex pq = extent_.h() * extent_.w(); - - LongIndex extent_row = extent_.n() * pq; - LongIndex extent_col = extent_.c(); - - LongIndex k_major = (global_offset_col_ / InterleavedN) * pq; - Index k_minor = global_offset_col_ % InterleavedN; - LongIndex k_offset = k_major * InterleavedN + k_minor; - LongIndex k_offset_delta = pq * InterleavedN; - - LongIndex stride_n = pq * extent_.c(); - - Index n; - LongIndex pq_rem; - - unsigned int pq_mul, pq_shr; - find_divisor(pq_mul, pq_shr, pq); - - if(beta_ == 0.0f) { - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < int(frag.size()); ++i) { - output_frag_f[i] = frag[i]; - } - - if(InstructionShape::kM == Policy::kStridedPerSTG) { - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < int(frag.size()); ++i) { - output_frag[i] = (Element)(output_frag_f[i] * alpha_); - } - } else { - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < int(frag.size()); ++i) { - int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor) - + (i % (8 * Policy::kPackedFactor)) / 2 * 4 - + (i % (8 * Policy::kPackedFactor)) % 2 - + (i / (8 * Policy::kPackedFactor)) % 2 * 2; - output_frag[i] = (Element)(output_frag_f[map_i] * alpha_); - } - } - - AccessType const *frag_ptr = reinterpret_cast(&output_frag); - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - int accum_m = mma_m * Policy::kStridedPerSTG; - - fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr); - LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - - int accum_n = mma_n * InterleavedN; - - int idx = mma_n + mma_m * Policy::MmaIterations::kColumn; - - if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) { - AccessType* access_ptr = reinterpret_cast(offset_ref.data() + - offset_m + mma_n * k_offset_delta); - - access_ptr[0] = frag_ptr[idx]; - } - } - } - } else { - if(InstructionShape::kM == Policy::kStridedPerSTG) { - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < int(frag.size()); ++i) { - output_frag_f[i] = frag[i]; - } - } else { - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < int(frag.size()); ++i) { - int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor) - + (i % (8 * Policy::kPackedFactor)) / 2 * 4 - + (i % (8 * Policy::kPackedFactor)) % 2 - + (i / (8 * Policy::kPackedFactor)) % 2 * 2; - output_frag_f[i] = frag[map_i]; - } - } - - AccessType const *frag_ptr = reinterpret_cast(&output_frag); - - Array ref_frag; - AccessType *ref_frag_ptr = reinterpret_cast(&ref_frag); - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - int accum_m = mma_m * Policy::kStridedPerSTG; - - fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr); - LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - - int accum_n = mma_n * InterleavedN; - - int idx = mma_n + mma_m * Policy::MmaIterations::kColumn; - - if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) { - AccessType* access_ptr = reinterpret_cast(offset_ref.data() + - offset_m + mma_n * k_offset_delta); - - ref_frag_ptr[0] = access_ptr[0]; - - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < kElementsPerAccess; ++i) { - output_frag[idx * kElementsPerAccess + i] = Element(alpha_ * output_frag_f[idx * kElementsPerAccess + i] - + beta_ * ref_frag[i]); - } - - access_ptr[0] = frag_ptr[idx]; - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h deleted file mode 100644 index 0d1da845ca08e1999403c5e34260b8e54bb6a85c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ /dev/null @@ -1,3096 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm70.h" - -#include "cutlass/platform/platform.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads> -class MmaVoltaTensorOpMultiplicandTileIterator; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand::kA, Element_, - cutlass::layout::VoltaTensorOpMultiplicandCongruous< - sizeof_bits::value>, - InstructionShape_, OpDelta_, 32> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::VoltaTensorOpMultiplicandCongruous::value>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Shape of one individual LDS.128 - using LdsShape = layout::PitchLinearShape< - 32, - 4 - >; - - // LdsShapes are arranged in the strided direction in SMEM - using LdsIterations = layout::PitchLinearShape< - InstructionShape::kStrided / LdsShape::kStrided, - Shape::kContiguous / LdsShape::kContiguous - >; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Number of internal pointers needed to reference shared memory - static int const kPointerCount = 2; - - /// Pointer type used for accesses - using AccessType = AlignedArray; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_[kPointerCount]; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { - // swizzle patterns for operandA LDS are - // 1. (tid[4] << 3) | (tid[2:0] ^ tid[4]) - // 2. (tid[4] << 3) | (tid[2:0] ^ tid[4] ^ 0b10010) - - int vec_row = (lane_id >> 4); // tid[4] - int vec_col = ((lane_id & 4) >> 2); // tid[2] - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPointerCount; ++i) { - - if(i == 1) { - vec_row |= 2; - } - int access_contiguous_idx = (vec_col << 2) | ((lane_id & 3) ^ vec_row); - int access_contiguous = access_contiguous_idx; - - int access_strided = vec_row; - pointer_[i] = reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - } - - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int contiguous_offset = tile_offset.contiguous(); - int strided_offset = tile_offset.strided(); - - // To support 32x32 tile size - if (Shape::kContiguous == Policy::LdsShape::kContiguous) { - if (contiguous_offset % 2) { - AccessType const *tmp_pointer = pointer_[0]; - pointer_[0] = pointer_[1]; - pointer_[1] = tmp_pointer; - } - contiguous_offset = contiguous_offset / 2 * 2; - } - - int offset = (strided_offset * InstructionShape::kStrided) * stride_ * - Layout::kElementsPerAccess + - contiguous_offset * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator++() { - byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator--() { - byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType * fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsIterations::kContiguous; - - AccessType const *source_ptr = pointer_[s & 1] + - Policy::LdsShape::kContiguous * c + - Policy::LdsShape::kStrided * (s / 2) * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / - Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> - -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand::kB, Element_, - cutlass::layout::VoltaTensorOpMultiplicandBCongruous< - sizeof_bits::value>, - InstructionShape_, OpDelta_, 32> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kB; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::VoltaTensorOpMultiplicandBCongruous::value>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kContiguous % InstructionShape::kContiguous), - "Shape of warp-level Mma must be divisible by operator shape."); - - // Shape of one individual LDS - using LdsShape = layout::PitchLinearShape< - 32, - 4 - >; - - using LdsIterations = layout::PitchLinearShape< - Shape::kContiguous / LdsShape::kContiguous, - InstructionShape::kStrided / LdsShape::kStrided - >; - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile, needs on more time number of registers - using Fragment = Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { - - // swizzle pattern is (tid & (3 << 3) | (tid[1:0] ^ tid[4:3])) - int access_strided = (lane_id >> 3) & 0x3; - int access_contiguous = ((lane_id ^ (lane_id >> 3)) & 0x3); - - pointer_ = reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int contiguous_offset = tile_offset.contiguous(); - int strided_offset = tile_offset.strided(); - - int offset = (strided_offset * InstructionShape::kStrided) * stride_ * - Layout::kElementsPerAccess + - contiguous_offset * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator++() { - byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator--() { - byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * - Layout::kElementsPerAccess; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType * fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsIterations::kContiguous; - - AccessType const *source_ptr = pointer_ + - Policy::LdsShape::kContiguous / Layout::kElementsPerAccess * c + - Policy::LdsShape::kStrided * s * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / - Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand::kA, Element_, - cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< - sizeof_bits::value>, - InstructionShape_, OpDelta_, 32> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaVoltaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::VoltaTensorOpMultiplicandCongruous::value>, - layout::PitchLinearShape, - kOpDelta, kThreads>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand::kB, Element_, - cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous< - sizeof_bits::value>, - InstructionShape_, OpDelta_, 32> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kB; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaVoltaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::VoltaTensorOpMultiplicandBCongruous::value>, - layout::PitchLinearShape, - kOpDelta, kThreads>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.strided(), tile_offset.contiguous()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -/// accumulator layout. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions, concept: MatrixShape) - typename OpDelta_> -class MmaVoltaTensorOpAccumulatorTileIterator { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kC; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - - /// Volta Tensor Op uses 32x32 interleaved tile - using InterleavedTile = MatrixShape<32, 32>; - - static_assert(!(Shape::kRow % InterleavedTile::kRow) && !(Shape::kColumn % InterleavedTile::kColumn), - "Shape of warp-level Mma must be divisible by operator shape."); - - static_assert(platform::is_same::value, - "Layouts must be defined for logical MatrixCoord coordinate space."); - - /// Number of mma operations performed - using TileIterations = MatrixShape< - Shape::kRow / InterleavedTile::kRow, - Shape::kColumn / InterleavedTile::kColumn - >; - - using MmaIterations = - MatrixShape; - }; - -private: - - // Assume accumulator tile is multipile interleaved 32x32 tile. - static int const kElementsPerPartial = 4; - using EleShapePerPatial = typename platform::conditional< - platform::is_same::value, - MatrixShape<2, 2>, - MatrixShape<1, 4> >::type; - static int const kElementsPerMma = 8; - static int const kAccumulatorPatials = 2; - using QuadShapePerPatialMma = MatrixShape<4, 4>; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - -private: - - /// Reference to output tensor - TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): - ref_(ref) { - - int quad = (lane_id >> 2); - int lane_in_quad = (lane_id & 3); - int accum_m, accum_n; - - if (platform::is_same::value) { - // (quad[2],quad[0])+lane_in_quad[0] - accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); - // (quad[1])+lane_in_quad[1] - accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + - (lane_in_quad & 2); - } else { - accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) - accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; - } - MatrixCoord lane_offset(accum_m, accum_n); - - ref_.add_coord_offset(lane_offset); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator & operator++() { - // deliberate no-op - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator & operator--() { - // deliberate no-op - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_HOST_DEVICE - void load_with_pointer_offset( - Fragment &frag, ///< fragment to load from the tensor - Index pointer_offset) const { ///< loads a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { - CUTLASS_PRAGMA_UNROLL - for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = - (((tile_n * Policy::TileIterations::kRow + tile_m) * - Policy::MmaIterations::kColumn + mma_n) * - Policy::MmaIterations::kRow + mma_m) * - kElementsPerMma; - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < kAccumulatorPatials; ++p) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < EleShapePerPatial::kRow; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { - int accum_m = tile_m * Policy::InterleavedTile::kRow + - mma_m * QuadShapePerPatialMma::kRow + m * 2; - int accum_n = tile_n * Policy::InterleavedTile::kColumn + - mma_n * QuadShapePerPatialMma::kColumn + - p * Policy::InterleavedTile::kColumn/2 + n; - int idx = mma_accum_start + p * kElementsPerPartial + - m * EleShapePerPatial::kColumn + n; - frag[idx] = offset_ref.at({accum_m, accum_n}); - } - } - } - } - } - } - } - } - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - Fragment &frag, ///< fragment to load from the tensor - Index byte_offset) const { ///< loads a tile with a linear offset - - load_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_HOST_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles - - load(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_HOST_DEVICE - void load( - Fragment &frag, ///< fragment to load from the tensor - TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles - Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset - - load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } - - /// Stores a fragment to memory - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_HOST_DEVICE - void store_with_pointer_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index pointer_offset) const { ///< store a tile with a linear offset - - TensorRef offset_ref(ref_); - offset_ref.add_pointer_offset(pointer_offset); - - CUTLASS_PRAGMA_UNROLL - for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { - CUTLASS_PRAGMA_UNROLL - for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { - - int mma_accum_start = - (((tile_n * Policy::TileIterations::kRow + tile_m) * - Policy::MmaIterations::kColumn + mma_n) * - Policy::MmaIterations::kRow + mma_m) * - kElementsPerMma; - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < kAccumulatorPatials; ++p) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < EleShapePerPatial::kRow; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { - int accum_m = tile_m * Policy::InterleavedTile::kRow + - mma_m * QuadShapePerPatialMma::kRow + m * 2; - int accum_n = tile_n * Policy::InterleavedTile::kColumn + - mma_n * QuadShapePerPatialMma::kColumn + - p * Policy::InterleavedTile::kColumn/2 + n; - int idx = mma_accum_start + p * kElementsPerPartial + - m * EleShapePerPatial::kColumn + n; - offset_ref.at({accum_m, accum_n}) = frag[idx]; - } - } - } - } - } - } - } - } - - /// Stores a fragment to memory with additional pointer offset - CUTLASS_HOST_DEVICE - void store_with_byte_offset( - Fragment const &frag, ///< fragment to store from the tensor - Index byte_offset) const { ///< store a tile with a linear offset - - store_with_pointer_offset(byte_offset / sizeof(Element)); - } - - /// Stores a fragment to memory with logical offset in units of whole tiles. - CUTLASS_HOST_DEVICE - void store( - Fragment &frag, ///< fragment to store to the tensor - TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles - - store(frag, tile_offset, 0); - } - - /// Stores a fragment from memory with logical offset in units of whole tiles. - CUTLASS_HOST_DEVICE - void store( - /// fragment to store to the tensor - Fragment const &frag, - /// stores a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// stores a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); - } -}; - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// KBlock size (in units of elements) - int KBlock> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::VoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, KBlock>, - InstructionShape_, OpDelta_, 32> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand == Operand::kB, - "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for " - "A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// KBlock size - static int const kKBlock = KBlock; - - /// Layout of source tile - using Layout = cutlass::layout::VoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, kKBlock>; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - - /// Shape of one individual LDS instruction - using LdsShape = layout::PitchLinearShape<1, 32>; - - /// Number and arrangement of LDSM instructions - using LdsIterations = layout::PitchLinearShape<1, Shape::kStrided / 32>; - - /// Using LDS.128 - static int const kElementsPerAccess = 8; - - /// Contiguous elements per line - static int const kContiguousElementsPerLine = 4; - }; - - private: - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - - private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - /// Crosswised elements are arranged in a SMEM line - /// in units of AccessType - Index line_size; - - /// Internal counter used to determine load addr offset - /// and when to swap higher 64bit with lower 64bit - int k_group_idx_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator() - : pointer_(nullptr), - stride_(0), - line_size(0), - byte_offset_(0), - k_group_idx_(0) {} - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : pointer_(reinterpret_cast(ref.data())), - stride_(ref.stride(0) * Policy::kElementsPerAccess), - line_size((ref.stride(0) * Policy::kContiguousElementsPerLine) / - Policy::kElementsPerAccess), - k_group_idx_(0), - byte_offset_(0) { - - int quad = (lane_id / 4); - int lane_in_quad = (lane_id % 4); - int access_contiguous; - - if(kOperand == Operand::kA) { - - // swizzle id: tid[4]|tid[1:0]|(tid[2]^tid[4]) - access_contiguous = ((quad & 0x4) << 1) + ((lane_in_quad) << 1) + - ((quad & 0x1) ^ ((quad & 0x4) >> 2)); - } else { - - // swizzle id: tid[4]|tid[1:0]|tid[3] - access_contiguous = ((quad & 0x4) << 1) + (lane_in_quad << 1) + - ((quad & 0x2) >> 1 ^ ((quad & 0x4) >> 2)); - } - - byte_offset_ = access_contiguous * - sizeof(Element) * Policy::kElementsPerAccess; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - - int contiguous_offset = tile_offset.contiguous(); - int strided_offset = tile_offset.strided(); - k_group_idx_ = 0; - - pointer_ += contiguous_offset * - (InstructionShape::kContiguous / - Policy::kContiguousElementsPerLine) * - line_size + - strided_offset * Shape::kStrided / 2; - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator++() { - k_group_idx_ = (k_group_idx_ + 1) % 8; - - if (k_group_idx_ == 4 || k_group_idx_ == 0) { - byte_offset_ ^= 1 * sizeof(Element) * Policy::kElementsPerAccess; - } - - pointer_ += line_size; - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator--() { assert(0); } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType * fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsIterations::kContiguous; - - AccessType const *source_ptr = pointer_ + - Policy::LdsShape::kContiguous * c * line_size + - Policy::LdsShape::kStrided * s / 2; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); - - // swap higher 64bit and lower 64bit - if (k_group_idx_ & 0x2) { - uint64_t *low = reinterpret_cast(&frag) + access_idx * 2; - uint64_t *high = reinterpret_cast(&frag) + access_idx * 2 + 1; - uint64_t tmp = *low; - *low = *high; - *high = tmp; - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = tile_offset.contiguous() * - InstructionShape::kContiguous / - Policy::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - k_group_idx_ = k_group; - } -}; - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// KBlock size (in units of elements) - int KBlock> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, KBlock>, - InstructionShape_, OpDelta_, 32> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand == Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for " - "A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// KBlock size - static int const kKBlock = KBlock; - - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, kKBlock>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaVoltaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::VoltaTensorOpMultiplicandCrosswise::value, - kKBlock>, - layout::PitchLinearShape, - kOpDelta, kThreads>; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - - private: - /// Underlying tile iterator - Base iterator_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator() {} - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : iterator_({ref.data(), ref.stride()}, lane_id) {} - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator++() { - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator--() { - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { iterator_.load(frag); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, {tile_offset.contiguous(), tile_offset.strided()}, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// KBlock size (in units of elements) - int KBlock> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, KBlock>, - InstructionShape_, OpDelta_, 32> { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand == Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for " - "A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// KBlock size - static int const kKBlock = KBlock; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, kKBlock>; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaVoltaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::VoltaTensorOpMultiplicandCrosswise::value, - kKBlock>, - layout::PitchLinearShape, - kOpDelta, kThreads>; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - - private: - /// Underlying tile iterator - Base iterator_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator() {} - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) - : iterator_({ref.data(), ref.stride()}, lane_id) {} - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator++() { - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator--() { - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { iterator_.load(frag); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for 'TN' arrangement -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand_, - /// Data type of A elements - typename Element_, - /// Layout of matrix operand - typename Layout_, - /// Shape of one matrix production operation (concept: MatrixShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads = 32, - /// Number of partitions along K dimension - int PartitionsK_ = 1> -class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - /// Basic check - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Number of elements accessed per Shared Memory load - static int const kElementsPerAccess = 4; - -private: - - static int const kInterleavedTileRows = 32; - static int const kInterleavedTileColumns = 32; - static int const kInstructionsPerTile = 2; - - /// Rounded up instruction counts - using TileCount = MatrixShape< - Shape::kRow / kInterleavedTileRows, - Shape::kColumn / kInterleavedTileColumns - >; - - using FragmentCount = MatrixShape< - TileCount::kRow * kInstructionsPerTile, - TileCount::kColumn * kInstructionsPerTile - >; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array< - Element, - (kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess - >; - - /// Memory access type - using AccessType = AlignedArray; - -private: - - /// Underlying tensor reference - TensorRef ref_; - - /// Extent of tensor - MatrixCoord extent_; - - /// Origin - MatrixCoord origin_; - - /// Used to conditionally enable extents checking - bool divisible_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner(): divisible_(true) { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner( - TensorRef const &ref, - int lane_id - ): - ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { - - int quad_id = lane_id / 4; - int lane_in_quad = (lane_id % 4); - - if (kOperand == Operand::kA) { - - int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad; - int col_idx = 0; - - origin_ = MatrixCoord(row_idx, col_idx); - } - else { - - int row_idx = 0; - int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad; - - origin_ = MatrixCoord(row_idx, col_idx); - } - - ref_.add_coord_offset(origin_); - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner( - TensorRef const &ref, - TensorCoord extent, - int lane_id - ): ref_(ref), extent_(extent), divisible_(false) { - - int quad_id = lane_id / 4; - int lane_in_quad = (lane_id % 4); - - if (kOperand == Operand::kA) { - - int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad; - int col_idx = 0; - - origin_ = MatrixCoord(row_idx, col_idx); - } - else { - - int row_idx = 0; - int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad; - - origin_ = MatrixCoord(row_idx, col_idx); - } - - #if defined(__CUDA_ARCH__) - __syncthreads(); - #endif - - ref_.add_coord_offset(origin_); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_pointer_offset(LongIndex offset) { - - ref_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_tile_offset(TensorCoord const &tile_offset) { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - origin_ += coord_offset; - - ref_.add_coord_offset(coord_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator++() { - - if (kOperand == Operand::kA) { - add_tile_offset({0, 1}); - } - else { - add_tile_offset({1, 0}); - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator--() { - - if (kOperand == Operand::kA) { - add_tile_offset({0, -1}); - } - else { - add_tile_offset({-1, 0}); - } - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - AccessType *frag_ptr = reinterpret_cast(&frag); - AccessType const *access_ptr = reinterpret_cast(ref_.data()); - int ldm = ref_.stride()[0]; - - if (kOperand == Operand::kA) { - - CUTLASS_PRAGMA_UNROLL - for (int idx = 0; idx < FragmentCount::kRow; ++idx) { - - int tile_idx = idx / 2; - int quad_idx = idx % 2; - - int row_offset = tile_idx * kInterleavedTileRows + quad_idx * 4; - frag_ptr[idx] = access_ptr[row_offset * ldm / kElementsPerAccess]; - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int idx = 0; idx < FragmentCount::kColumn; ++idx) { - - int tile_idx = idx / 2; - int quad_idx = idx % 2; - - int col_offset = tile_idx * kInterleavedTileColumns + quad_idx * 4; - frag_ptr[idx] = access_ptr[col_offset * ldm / kElementsPerAccess]; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - - load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation - } -}; - - -/// Tile iterator specialized for 'NT' arrangement -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand_, - /// Data type of A elements - typename Element_, - /// Layout of matrix operand - typename Layout_, - /// Shape of one matrix production operation (concept: MatrixShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads = 32, - /// Number of partitions along K dimension - int PartitionsK_ = 1> -class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - /// Basic check - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Number of elements accessed per Shared Memory load - static int const kElementsPerAccess = 4; - -private: - - static int const kInterleavedTileRows = 32; - static int const kInterleavedTileColumns = 32; - static int const kInstructionsPerTile = 2; - - /// Rounded up instruction counts - using TileCount = MatrixShape< - Shape::kRow / kInterleavedTileRows, - Shape::kColumn / kInterleavedTileColumns - >; - - using FragmentCount = MatrixShape< - TileCount::kRow * kInstructionsPerTile, - TileCount::kColumn * kInstructionsPerTile - >; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array< - Element, - (kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess - >; - - /// Memory access type - using AccessType = AlignedArray; - -private: - - /// Underlying tensor reference - TensorRef ref_; - - /// Extent of tensor - MatrixCoord extent_; - - /// Origin - MatrixCoord origin_; - - /// Used to conditionally enable extents checking - bool divisible_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter(): divisible_(true) { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter( - TensorRef const &ref, - int lane_id - ): - ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { - - int quad_id = lane_id / 4; - int lane_in_quad = (lane_id % 4); - - if (kOperand == Operand::kA) { - - int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile; - int col_idx = lane_in_quad; - - origin_ = MatrixCoord(row_idx, col_idx); - } - else { - - int row_idx = lane_in_quad; - int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile; - - origin_ = MatrixCoord(row_idx, col_idx); - } - - ref_.add_coord_offset(origin_); - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter( - TensorRef const &ref, - TensorCoord extent, - int lane_id - ): ref_(ref), extent_(extent), divisible_(false) { - - int quad_id = lane_id / 4; - int lane_in_quad = (lane_id % 4); - - if (kOperand == Operand::kA) { - - int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile; - int col_idx = lane_in_quad; - - origin_ = MatrixCoord(row_idx, col_idx); - } - else { - - int row_idx = lane_in_quad; - int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile; - - origin_ = MatrixCoord(row_idx, col_idx); - } - - #if defined(__CUDA_ARCH__) - __syncthreads(); - #endif - - ref_.add_coord_offset(origin_); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_pointer_offset(LongIndex offset) { - - ref_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_tile_offset(TensorCoord const &tile_offset) { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - origin_ += coord_offset; - - ref_.add_coord_offset(coord_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator++() { - - if (kOperand == Operand::kA) { - add_tile_offset({0, 1}); - } - else { - add_tile_offset({1, 0}); - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator--() { - - if (kOperand == Operand::kA) { - add_tile_offset({0, -1}); - } - else { - add_tile_offset({-1, 0}); - } - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - AccessType *frag_ptr = reinterpret_cast(&frag); - AccessType const *access_ptr = reinterpret_cast(ref_.data()); - int ldm = ref_.stride()[0]; - - if (kOperand == Operand::kA) { - - CUTLASS_PRAGMA_UNROLL - for (int idx = 0; idx < FragmentCount::kRow; ++idx) { - - int tile_idx = idx / 2; - int quad_idx = idx % 2; - - int row_offset = tile_idx * kInterleavedTileRows; - frag_ptr[idx] = access_ptr[row_offset / kElementsPerAccess + quad_idx]; - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int idx = 0; idx < FragmentCount::kColumn; ++idx) { - - int tile_idx = idx / 2; - int quad_idx = idx % 2; - - int col_offset = tile_idx * kInterleavedTileColumns; - frag_ptr[idx] = access_ptr[col_offset / kElementsPerAccess + quad_idx]; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - - load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, - Operand::kA, - Element_, - cutlass::layout::RowMajor, - InstructionShape_, - OpDelta_, - 32 -> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< - Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { - -public: - using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< - Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> ; - - using TensorRef = typename Base::TensorRef; - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): Base(ref, lane_id) { } - -}; - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, - Operand::kA, - Element_, - cutlass::layout::ColumnMajor, - InstructionShape_, - OpDelta_, - 32 -> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< - Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> { - -public: - using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< - Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> ; - - using TensorRef = typename Base::TensorRef; - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): Base(ref, lane_id) { } - -}; - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand::kB, Element_, - cutlass::layout::ColumnMajor, - InstructionShape_, OpDelta_, 32 -> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< - Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> { - -public: - using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< - Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_>; - - using TensorRef = typename Base::TensorRef; - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): Base(ref, lane_id) { } -}; - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_> -class MmaVoltaTensorOpMultiplicandTileIterator< - Shape_, Operand::kB, Element_, - cutlass::layout::RowMajor, - InstructionShape_, OpDelta_, 32 -> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< - Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { - -public: - using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< - Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_>; - - using TensorRef = typename Base::TensorRef; - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaVoltaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): Base(ref, lane_id) { } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h deleted file mode 100644 index a5370ff8f14a3e384da392782cdc26c1f34a4eff..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h +++ /dev/null @@ -1,2440 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for loading 128b vectors of 64b elements. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicandCongruous64b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 4), "Divisibility."); - - static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Load two elements per access - static int const kElementsPerAccess = 2; - - /// Policy defining internal details of tile iterator - struct Policy { - - /// Shape of one access - using Delta = layout::PitchLinearShape<8, 4>; - - /// Number of iterations to load - using Iterations = layout::PitchLinearShape< - Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, - InstructionShape::kStrided / Delta::kStrided - >; - - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - - /// Internal counter used to jump to next K partition - int k_group_idx_; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), - k_group_idx_(0) { - - int access_strided = lane_id / Policy::Delta::kContiguous; - int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; - - pointer_= reinterpret_cast(ref.data()) + - access_contiguous + access_strided * stride_; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += offset * sizeof(Element); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - int offset = - (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + - tile_offset.contiguous() * Shape::kContiguous; - - add_pointer_offset(offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - add_tile_offset({0, 1}); - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - add_tile_offset({0, -1}); - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::Iterations::kStrided; ++s) { - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::Iterations::kContiguous; - - AccessType const *source_ptr = pointer_ + - Policy::Delta::kContiguous * c + - Policy::Delta::kStrided * s * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - AccessType const *source = reinterpret_cast(source_byte_ptr); - - fetch_ptr[access_idx] = *source; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - - Index pointer_offset = - tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCongruous64b, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.strided(), tile_offset.contiguous()}, - byte_offset); - } - - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicandCongruous64b, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for loading 128b vectors of 64b elements. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::TensorOpMultiplicand64bCrosswise, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); - - static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Long Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Load two elements per access - static int const kElementsPerAccess = 2; - - /// Policy defining internal details of tile iterator - struct Policy { - - /// Shape of one access - using Delta = layout::PitchLinearShape<4, 16>; - - /// Number of iterations to load - using Iterations = layout::PitchLinearShape< - InstructionShape::kContiguous / Delta::kContiguous, - Shape::kStrided / Delta::kStrided - >; - - }; - -private: - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = AlignedArray; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - -private: - - /// Layout object storing stride values - StrideIndex stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - /// Internal counter for tracking K-group - Index k_group_idx_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): - stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), - k_group_idx_(0) { - - int access_strided = lane_id / 8; - int access_contiguous = (lane_id % 8); - - byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); - - pointer_= reinterpret_cast(ref.data()); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - pointer_ += offset / kElementsPerAccess; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * - stride_ * kElementsPerAccess + - tile_offset.strided() * Shape::kStrided; - - add_pointer_offset(offset); - - int old_k_group_idx = k_group_idx_; - - k_group_idx_ += tile_offset.contiguous(); - - if ((k_group_idx_ & 2) ^ (old_k_group_idx & 2)) { - byte_offset_ ^= 0x40; - } - - return *this; - } - - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { - - add_tile_offset(tile_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - pointer_ += stride_ * InstructionShape::kContiguous; - - if (k_group_idx_ & 0x1) { - // xor ptr - byte_offset_ ^= 0x40; - } - - ++k_group_idx_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - AccessType *fetch_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::Iterations::kStrided; ++s) { - - int access_idx = c + s * Policy::Iterations::kContiguous; - - AccessType const *source_ptr = pointer_ + - Policy::Delta::kContiguous * c * stride_ + - Policy::Delta::kStrided * s / kElementsPerAccess; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; - - AccessType const *source = reinterpret_cast(source_byte_ptr); - - fetch_ptr[access_idx] = *source; - } - } - - Element *exchange_ptr = reinterpret_cast(&frag); - - if (k_group_idx_ & 1) { - // exchange on 64b granularity - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Fragment::kElements; i += 2) { - Element tmp = exchange_ptr[i]; - exchange_ptr[i] = exchange_ptr[i + 1]; - exchange_ptr[i + 1] = tmp; - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = tile_offset.contiguous() * - InstructionShape::kContiguous / - Layout::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - k_group_idx_ = k_group; - } -}; - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicand64bCrosswise, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.strided(), tile_offset.contiguous()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIterator< - layout::PitchLinearShape, kOperand, Element, - layout::TensorOpMultiplicand64bCrosswise, - layout::PitchLinearShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - - -/// Tile iterator specialized for canonical matrix layouts -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand_, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of one matrix production operation (concept: MatrixShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads = 32, - /// Number of partitions along K dimension - int PartitionsK_ = 1> -class MmaTensorOpMultiplicandTileIteratorCanonical { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - /// Basic check - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Number of elements accessed per Shared Memory load - static int const kElementsPerAccess = - (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); - -private: - - static int const kWarpShapeOuter = - (kOperand == Operand::kA ? Shape::kRow : Shape::kColumn); - - static int const kWarpShapeInner = - (kOperand == Operand::kA ? Shape::kColumn : Shape::kRow); - - - /// Rounded up instruction counts - using InstructionCount = MatrixShape< - Shape::kRow / InstructionShape::kRow, - Shape::kColumn / InstructionShape::kColumn - >; - - /// Rounded up tile dimensions - using WarpShapeDivisible = MatrixShape< - InstructionCount::kRow * InstructionShape::kRow, - InstructionCount::kColumn * InstructionShape::kColumn - >; - -public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array< - Element, - WarpShapeDivisible::kRow * WarpShapeDivisible::kColumn / kThreads - >; - - /// Memory access type - using AccessType = AlignedArray; - -private: - - /// Underlying tensor reference - TensorRef ref_; - - /// Extent of tensor - MatrixCoord extent_; - - /// Origin - MatrixCoord origin_; - - /// Used to conditionally enable extents checking - bool divisible_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical(): divisible_(true) { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical( - TensorRef const &ref, - int lane_id - ): ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { - - if (kOperand == Operand::kA) { - origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); - } - else { - origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); - } - - ref_.add_coord_offset(origin_); - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical( - TensorRef const &ref, - TensorCoord extent, - int lane_id - ): ref_(ref), extent_(extent), divisible_(false) { - - if (kOperand == Operand::kA) { - origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); - } - else { - origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); - } - - ref_.add_coord_offset(origin_); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical &add_pointer_offset(LongIndex offset) { - - ref_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical &add_tile_offset(TensorCoord const &tile_offset) { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - origin_ += coord_offset; - - ref_.add_coord_offset(coord_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical & operator++() { - - if (kOperand == Operand::kA) { - add_tile_offset({0, 1}); - } - else { - add_tile_offset({1, 0}); - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical & operator--() { - - if (kOperand == Operand::kA) { - add_tile_offset({0, -1}); - } - else { - add_tile_offset({-1, 0}); - } - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIteratorCanonical & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - load_with_pointer_offset(frag, 0); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - int const kWarpShapeDivisibleInner = - (kOperand == Operand::kA ? WarpShapeDivisible::kColumn : WarpShapeDivisible::kRow); - - // Take advantage of Tensor Op's 8 x 4T access pattern - int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; - - AccessType *access_ptr = reinterpret_cast(&frag); - - if (kOperand == Operand::kA) { - int const kTilesPerInstruction = InstructionShape::kRow / 8; - - CUTLASS_PRAGMA_UNROLL - for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { - - CUTLASS_PRAGMA_UNROLL - for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { - - CUTLASS_PRAGMA_UNROLL - for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { - int access_idx = - access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); - - MatrixCoord offset( - access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, - inner_idx * 4 * kElementsPerAccess); - - MatrixCoord access_coord = origin_ + offset; - - if (divisible_ || - (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { - - access_ptr[access_idx] = *reinterpret_cast( - ref_.data() + ref_.offset(offset)); - } - else { - AccessType zero; - zero.clear(); - access_ptr[access_idx] = zero; - } - } - } - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { - - CUTLASS_PRAGMA_UNROLL - for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { - int access_idx = inner_idx + kAccessesInner * inst_n_idx; - - MatrixCoord offset( - inner_idx * 4 * kElementsPerAccess, - inst_n_idx * 8); - - MatrixCoord access_coord = origin_ + offset; - - if (divisible_ || - (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { - - access_ptr[access_idx] = *reinterpret_cast( - ref_.data() + ref_.offset(offset)); - } - else { - AccessType zero; - zero.clear(); - access_ptr[access_idx] = zero; - } - } - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - - load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - - TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); - - load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation - } -}; - -/// Wrapper for ColumnMajor -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::ColumnMajor, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::ColumnMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIteratorCanonical< - Shape, kOperand, Element, - layout::ColumnMajor, - InstructionShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - TensorCoord const & extent, - int lane_id - ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - - -/// Wrapper for RowMajor -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Identifies A or B multiplicand - Operand Operand_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Interval between adjacent *MMA instructions (in units of MMA - /// instructions) - int OpDelta_, - /// Number of partitions along K dimension - int PartitionsK_> -class MmaTensorOpMultiplicandTileIterator< - Shape_, Operand_, Element_, - cutlass::layout::RowMajor, - InstructionShape_, OpDelta_, 32, PartitionsK_> { - public: - - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand_; - - static_assert(kOperand == Operand::kA || kOperand== Operand::kB, - "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Underlying tile iterator implementation - using Base = MmaTensorOpMultiplicandTileIteratorCanonical< - Shape, kOperand, Element, - layout::RowMajor, - InstructionShape, - kOpDelta, kThreads, PartitionsK_>; - - public: - - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - -private: - - /// Underlying tile iterator - Base iterator_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): iterator_({ref.data(), ref.stride()}, lane_id) { - } - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator( - TensorRef const &ref, - TensorCoord const &extent, - int lane_id - ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator++() { - - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpMultiplicandTileIterator & operator--() { - - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - iterator_.load(frag); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, - {tile_offset.contiguous(), tile_offset.strided()}, - byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h deleted file mode 100644 index 97f7e14f940ff29ff257ba18d2dfa6f5e844ea25..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h +++ /dev/null @@ -1,380 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators to load sparse meta data used by warp-level matrix multiply operations - targeting Sparse Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads, - /// Number of partitions along K dimension - int PartitionsK_ = 1> -class SparseMmaTensorOpMetaTileIterator { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: - /// MatrixShape) - static int const kOpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - static int const kSparse = 2; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - struct Policy { - static_assert( - !(Shape::kColumn % InstructionShape::kColumn), - "Shape of warp-level Mma must be divisible by operator shape."); - - static int const kElementsPerAccess = 128 / sizeof_bits::value; - - // Determine number of elements along outer dimension per individual LDSM op - static int const kLdsmOpOuter = InstructionShape::kColumn; - static int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter; - - static_assert(!(Shape::kColumn % kLdsmOpOuter), - "Shape of warp-level mma must be divisible by LDSM's " - "fundamental tile size."); - - static_assert(!(Shape::kRow % kLdsmOpInner), - "Shape of warp-level mma must be divisible by LDSM's " - "fundamental tile size."); - - /// Shape of one individual LDSM instruction - static int const LdsmShapeColumn = - InstructionShape::kColumn / kLdsmOpOuter; - static int const LdsmShapeRow = - ((4 / LdsmShapeColumn * kLdsmOpInner) > Shape::kRow) - ? (Shape::kRow / kLdsmOpInner) - : (4 / LdsmShapeColumn); - using LdsmShape = - layout::PitchLinearShape; - - /// Number and arrangement of LDSM instructions - using LdsmIterations = layout::PitchLinearShape< - Shape::kRow / kLdsmOpInner / LdsmShapeRow, - 1>; - - /// Number of groups for each tile - static int const kGroupsPerTile = - Shape::kColumn / InstructionShape::kColumn; - }; - - private: - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - /// Pointer type used for accesses - using AccessType = Array; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = - Array; - - private: - - /// Layout object storing stride values - Index stride_; - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - /// Internal counter used to determine when to increment byte offset and when - /// to XOR it - int k_group_idx_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - SparseMmaTensorOpMetaTileIterator() - : pointer_(nullptr), - stride_(0), - byte_offset_(0), - k_group_idx_(0) {} - - /// Constructor from TensorRef - CUTLASS_DEVICE - SparseMmaTensorOpMetaTileIterator(TensorRef const &ref, int lane_id) - : pointer_(reinterpret_cast(ref.data())), - stride_(ref.stride(0) / Policy::kElementsPerAccess), - byte_offset_(0), - k_group_idx_(0) { - - int access_contiguous = (lane_id % (Shape::kRow / Policy::kElementsPerAccess)); - int access_strided = (lane_id / (Shape::kRow / Policy::kElementsPerAccess)); - - byte_offset_ = (access_contiguous + access_strided * stride_) * - sizeof_bits::value * Policy::kElementsPerAccess / 8; - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - SparseMmaTensorOpMetaTileIterator &add_pointer_offset(LongIndex offset) { - byte_offset_ += offset * sizeof_bits::value / 8; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - SparseMmaTensorOpMetaTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - int offset = tile_offset.row() * Shape::kRow + - tile_offset.column() * InstructionShape::kColumn * stride_ * - Policy::kElementsPerAccess; - - add_pointer_offset(offset); - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - SparseMmaTensorOpMetaTileIterator &operator++() { - add_tile_offset({0, 1}); - - if (kPartitionsK > 1) { - ++k_group_idx_; - // Jump to next stage - if (k_group_idx_ == Policy::kGroupsPerTile) { - k_group_idx_ = 0; - add_tile_offset( - {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); - } - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - SparseMmaTensorOpMetaTileIterator &operator--(){ - byte_offset_ -= stride_ * InstructionShape::kColumn * - sizeof_bits::value * Policy::kElementsPerAccess / - 8; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE SparseMmaTensorOpMetaTileIterator & - operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - SparseMmaTensorOpMetaTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - Array *fetch_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { - - int access_idx = c + s * Policy::LdsmIterations::kContiguous; - - AccessType const *source_ptr = - pointer_ + - Policy::LdsmShape::kContiguous * Policy::kLdsmOpInner * c + - Policy::LdsmShape::kStrided * s * stride_; - - char const *source_byte_ptr = reinterpret_cast(source_ptr) + - byte_offset + byte_offset_; - - cutlass::arch::ldsm( - fetch_ptr[access_idx], source_byte_ptr); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = - tile_offset.contiguous() * Shape::kRow / Layout::kElementsPerAccess + - tile_offset.strided() * InstructionShape::kColumn * stride_; - - byte_offset += sizeof(AccessType) * pointer_offset; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no op - } -}; - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h deleted file mode 100644 index 92e065f236fe8d62068487abb266a0e9c77fe712..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h +++ /dev/null @@ -1,805 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -*/ - -#pragma once - - -#include "cutlass/cutlass.h" -#include "cutlass/arch/wmma.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -#include "cutlass/wmma_array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// -template < - ///< Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity (A or B) - Operand Operand, - /// Data type of operand - typename Element_, - /// Layout of operand - typename Layout_, - /// Delta between *MMA operations (in units of *WMMA operations, concept:MatrixShape) - int OpDelta_, - /// Number of threads participating in one matrix operation - int Threads, - /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) - typename Policy_> -class MmaTensorOpWmmaMultiplicandTileIterator; - - -//////////////////////////////////////////////////////////////////////////////// -/// This tile iterator is specialized for 32-thread WMMA operation. -/// It uses nvcuda::wmma::load_matrix_sync to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -//////////////////////////////////////////////////////////////////////////////// -template < - ///< Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) - int OpDelta_, - /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) - typename Policy_> -class MmaTensorOpWmmaMultiplicandTileIterator< - Shape_, Operand::kA, Element_, Layout_, - OpDelta_, 32, Policy_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kA; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Delta between *WMMA operations - static int const kOpDelta = OpDelta_; - - /// Wmma Operator information and operation delta - using Policy = Policy_; - - - // - // Derived quantities - // - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Stride Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Native Wmma shape for operand A (concept MatrixShape) - using WmmaShape = MatrixShape< - Policy::Operator::Shape::kM, - Policy::Operator::Shape::kK - >; - - /// Map cutlass dataype to nvcuda::wmma datatype - using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; - - /// Shape of individual WMMA load / stores for operand A - using Iterations = MatrixShape< - Shape::kRow / WmmaShape::kRow, - 1 - >; - - /// Fragment object holding a warps part - using Fragment = WmmaFragmentArray; - - - ////////////////////////////////////////////////////////////////////////////////////////////////////// - /// statically assert this specialization - ///////////////////////////////////////////////////////////////////////////////////////////////////// - /// This iterator is specalized for Operand A - static_assert(kOperand == Operand::kA, - "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for A operands to warp-level Mma."); - - /// Supported memory layouts - static_assert( - platform::is_same::value || - platform::is_same::value, - "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - -private: - - /// Shared memory base pointers - not advanced - char const *pointer_; - - /// Byte offset into shared memory - advanced - Index byte_offset_; - - /// Stride in units of number of elements - StrideIndex stride_; - - /// Layout of shared memory - Layout layout_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): pointer_(reinterpret_cast(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { - - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - byte_offset_ += (offset * sizeof_bits::value) / 8; - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - Index elements_offset = layout_({tile_offset.row() * Shape::kRow, tile_offset.column() * WmmaShape::kColumn}); - - byte_offset_ += (elements_offset * sizeof_bits::value) / 8; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator++() { - - Index elements_offset = layout_({0, WmmaShape::kColumn}); - - byte_offset_ += (elements_offset * sizeof_bits::value) / 8; - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator--() { - - Index elements_offset = layout_({0, WmmaShape::kColumn}); - - byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load_with_byte_offset(Fragment &frag, Index byte_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kColumn; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - - Index load_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits::value / 8; - - const WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + load_byte_offset + byte_offset); - - nvcuda::wmma::load_matrix_sync(frag[m], ptr, stride_); - - } - } - } - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_byte_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_byte_offset(Fragment const &frag, Index byte_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kColumn; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - - Index store_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits::value / 8; - - WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + store_byte_offset + byte_offset); - - nvcuda::wmma::store_matrix_sync(ptr, frag[m], stride_); - - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_byte_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -/// This tile iterator is specialized for 32-thread WMMA operation. -/// It uses nvcuda::wmma::load_matrix_sync to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -//////////////////////////////////////////////////////////////////////////////// - -template < - ///< Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) - int OpDelta_, - /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) - typename Policy_> -class MmaTensorOpWmmaMultiplicandTileIterator< - Shape_, Operand::kB, Element_, Layout_, - OpDelta_, 32, Policy_> { - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Operand tag - static Operand const kOperand = Operand::kB; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Delta between *WMMA operations - static int const kOpDelta = OpDelta_; - - /// Wmma Operator information and operation delta - using Policy = Policy_; - - - // - // Derived quantities - // - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Stride Index type - using StrideIndex = typename TensorRef::Layout::Stride::Index; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Native Wmma shape (concept MatrixShape) - using WmmaShape = MatrixShape< - Policy::Operator::Shape::kK, - Policy::Operator::Shape::kN - >; - - /// Map cutlass dataype to nvcuda::wmma datatype - using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; - - /// Shape of individual WMMA load / stores for operand B - using Iterations = MatrixShape< - 1, - Shape::kColumn / WmmaShape::kColumn - >; - - /// Fragment object holding a warps part - using Fragment = WmmaFragmentArray; - - - ////////////////////////////////////////////////////////////////////////////////////////////////////// - /// statically asserts this specialization - ///////////////////////////////////////////////////////////////////////////////////////////////////// - /// This iterator is specalized for Operand B - static_assert(kOperand == Operand::kB, - "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for B operands to warp-level Mma."); - - /// Supported memory layouts - static_assert( - platform::is_same::value || - platform::is_same::value, - "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); - - /// Not working on this feature at the moment. - static_assert(kOpDelta == 1, - "Alternative arrangements not supported at present."); - - ///////////////////////////////////////////////////////////////////////////////////////////////////// - -private: - - /// Shared memory base pointers - not advanced - char const *pointer_; - - /// Byte offset into shared memory - advanced - Index byte_offset_; - - /// Stride in units of number of elements - StrideIndex stride_; - - /// Layout of shared memory - Layout layout_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator( - TensorRef const &ref, - int lane_id - ): pointer_(reinterpret_cast(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { - - byte_offset_ += (offset * sizeof_bits::value) / 8; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - - Index elements_offset = layout_({tile_offset.row() * WmmaShape::kRow, tile_offset.column() * Shape::kColumn}); - - byte_offset_ += (elements_offset * sizeof_bits::value) / 8; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator++() { - - Index elements_offset = layout_({WmmaShape::kRow, 0}); - - byte_offset_ += (elements_offset * sizeof_bits::value) / 8; - - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator--() { - - Index elements_offset = layout_({WmmaShape::kRow, 0}); - - byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load_with_byte_offset(Fragment &frag, Index byte_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kRow; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - - Index load_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits::value / 8; - - const WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + load_byte_offset + byte_offset); - - nvcuda::wmma::load_matrix_sync(frag[n], ptr, stride_); - } - } - } - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_byte_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_byte_offset(Fragment const &frag, Index byte_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Iterations::kRow; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - - Index store_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits::value / 8; - - WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + store_byte_offset + byte_offset); - - nvcuda::wmma::store_matrix_sync(ptr, frag[n], stride_); - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_byte_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - -//////////////////////////////////////////////////////////////////////////////// -template < - ///< Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Interval between adjacent *WMMA instructions (in units of WMMA instructions, concept: MatrixShape) - typename OpDelta_, - /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) - typename Policy_> -class MmaTensorOpWmmaAccumulatorTileIterator; - -//////////////////////////////////////////////////////////////////////////////// -/// This tile iterator is specialized for 32-thread WMMA operation. -/// It uses nvcuda::wmma::store_matrix_sync to load from shared -/// memory and therefore must be initialized with a TensorRef to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept | -/// WriteableRandomAccessContiguousTileIteratorConcept -/// -//////////////////////////////////////////////////////////////////////////////// - -template < - ///< Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) - typename OpDelta_, - /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) - typename Policy_> -class MmaTensorOpWmmaAccumulatorTileIterator -{ - public: - - /// Shape of tile to load (concept: MatrixShape) - using Shape = Shape_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = Layout_; - - /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) - using OpDelta = OpDelta_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Wmma Operator information and operation delta - using Policy = Policy_; - - - // - // Derived quantities - // - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Native Wmma shape (concept MatrixShape) - using WmmaShape = MatrixShape< - Policy::Operator::Shape::kM, - Policy::Operator::Shape::kN - >; - - /// Map cutlass dataype to nvcuda::wmma datatype - using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; - - /// Map cutlass::layout to nvuda::wmma::layout_t enum - static nvcuda::wmma::layout_t const WmmaLayout = cutlass::arch::CutlassToWmmaLayout::value; - - /// Shape of individual WMMA load / stores for accumulator - using Iterations = MatrixShape< - Shape::kRow / WmmaShape::kRow, - Shape::kColumn / WmmaShape::kColumn - >; - - /// Fragment object holding a thread's part of a tile - using Fragment = WmmaFragmentArray; - - ////////////////////////////////////////////////////////////////////////////////////////////////////// - /// statically asserts this specialization - ///////////////////////////////////////////////////////////////////////////////////////////////////// - /// Supported layouts - static_assert( - platform::is_same::value || - platform::is_same::value, - "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); - -private: - - /// Internal reference - cutlass::TensorRef ref_; - -public: - - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator() { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator( - TensorRef const &ref, - int lane_id - ): ref_(ref) { } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { - ref_.add_pointer_offset(offset); - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { - ref_.add_coord_offset({tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn}); - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator & operator++() { - ref_.add_coord_offset({Shape::kRow, 0}); - return *this; - } - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator & operator--() { - ref_.add_coord_offset({-Shape::kRow, 0}); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - MmaTensorOpWmmaAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - - const WmmaDataType * ptr = reinterpret_cast (ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset); - - nvcuda::wmma::load_matrix_sync(frag[m * Iterations::kColumn + n], ptr, ref_.stride()[0], WmmaLayout); - - } - } - } - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Iterations::kRow; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Iterations::kColumn; ++n) { - - WmmaDataType * ptr = reinterpret_cast (ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset); - - nvcuda::wmma::store_matrix_sync(ptr, frag[m * Iterations::kColumn + n], ref_.stride()[0], WmmaLayout); - } - } - } - - /// Stores a fragment to memory at the location pointed to by the iterator - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) const { - store_with_pointer_offset(frag, 0); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - // no operation here - } -}; - - - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// - -#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h deleted file mode 100644 index ec445443afd504a201b6788133099015dd52e7a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h +++ /dev/null @@ -1,223 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/wmma.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -#include "cutlass/wmma_array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///< Structure to compute the matrix product targeting CUDA cores via WMMA. -template < - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - ///< Data type of A elements - typename ElementA_, - ///< Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - ///< Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - ///< Element type of C matrix - typename ElementC_, - ///< Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - ///< Policy describing warp-level Wmma operation (concept: MmaTensorOpPolicy) - typename Policy_, - ///< Number of partitions along K dimension - int PartitionsK_ = 1, - ///< Used for partial specialization - typename Enable = bool -> -class MmaTensorOpWmma { -public: - ///< Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - ///< Data type of multiplicand A - using ElementA = ElementA_; - - ///< Layout of multiplicand A - using LayoutA = LayoutA_; - - ///< Data type of multiplicand B - using ElementB = ElementB_; - - ///< Layout of multiplicand B - using LayoutB = LayoutB_; - - ///< Data type of accumulator matrix C - using ElementC = ElementC_; - - ///< Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) - using Policy = Policy_; - - /// Underlying instruction shape - using InstructionShape = typename Policy::Operator::Shape; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Underlying architecture tag - using ArchTag = typename Policy::Operator::ArchTag; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassWmmaTensorOp; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, LayoutA, - Policy::OpDelta::kRow, kThreadCount, Policy>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator< - MatrixShape, Operand::kB, ElementB, LayoutB, - Policy::OpDelta::kRow, kThreadCount, Policy>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename Policy::OpDelta, Policy>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - -private: - - static_assert( - !(Shape::kM % Policy::Operator::Shape::kM) && - !(Shape::kN % Policy::Operator::Shape::kN), - "Shape of warp-level Wmma must be divisible by operator shape (wmma native size)"); - - /// Number of wmma operations performed - using WmmaIterations = MatrixShape< - Shape::kM / Policy::Operator::Shape::kM, - Shape::kN / Policy::Operator::Shape::kN - >; - -public: - - /// Underlying matrix multiply operator (concept: cutlass::arch::Wmma) - typename Policy::Operator wmma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpWmma() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < WmmaIterations::kColumn; ++n) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < WmmaIterations::kRow; ++m) { - - // accumulate wmma mma - wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n], C[m * WmmaIterations::kColumn + n]); - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h deleted file mode 100644 index d97c8f449f84e1cc3b08977b109aeda7c827d89f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h +++ /dev/null @@ -1,449 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Reduce operand A or B along K dimension - bool ReduceKForA_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> -class MmaWithReductionTensorOp { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - static bool const kReduceKForA = ReduceKForA_; - - static_assert(platform::is_same::value || - platform::is_same::value, - "ElementA needs to be fp16 or bf16."); - - static_assert(platform::is_same::value || - platform::is_same::value, - "ElementB needs to be fp16 or bf16."); - - static_assert(platform::is_same>::value, - "Only supports 16x8x16 tensor core instruction."); - - static_assert(!AccumulatorsInRowMajor, - "Only calls tensor core instructions in column major."); - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, LayoutA, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = - Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kB, ElementB, LayoutB, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = - Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN - >; - - using FragmentReduction = Array; - -public: - - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaWithReductionTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C, - FragmentReduction &gemm_k_reduction - ) const { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - D = C; - - [[maybe_unused]] MmaOperandA const *ptr_A = reinterpret_cast(&A); - [[maybe_unused]] MmaOperandB const *ptr_B = reinterpret_cast(&B); - [[maybe_unused]] MmaOperandC *ptr_D = reinterpret_cast(&D); - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - assert(0); - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - - if (!kReduceKForA && m == 0) { - #if 0 - gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); - gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); - gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); - gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 3]); - #else - uint32_t const *tmp = reinterpret_cast(&B); - - if (platform::is_same::value) { - asm volatile( - "{\n\t" - " .reg .f16 low, high;\n\t" - " .reg .f32 tmp;\n\t" - " mov.b32 {low, high}, %1;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " mov.b32 {low, high}, %2;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %0, tmp, %0;\n\t" - "}\n\t" - : "+f"(gemm_k_reduction[n_serpentine]) - : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); - } else if (platform::is_same::value) { - asm volatile( - "{\n\t" - " .reg .f32 tmp;\n\t" - " shl.b32 tmp, %1, 16;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " and.b32 tmp, %1, 0xffff0000;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " shl.b32 tmp, %2, 16;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " and.b32 tmp, %2, 0xffff0000;\n\t" - " add.f32 %0, tmp, %0;\n\t" - "}\n\t" - : "+f"(gemm_k_reduction[n_serpentine]) - : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); - } else { - assert(0); - } - #endif - } - - if (kReduceKForA && (n == 0)) { - #if 0 - gemm_k_reduction[m * 2] += float(A[m * 8]); - gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); - gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); - gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); - - gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); - gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); - gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); - gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 7]); - #else - uint32_t const *tmp = reinterpret_cast(&A); - - if (platform::is_same::value) { - asm volatile( - "{\n\t" - " .reg .f16 low, high;\n\t" - " .reg .f32 tmp;\n\t" - " mov.b32 {low, high}, %2;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " mov.b32 {low, high}, %3;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " mov.b32 {low, high}, %4;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " mov.b32 {low, high}, %5;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %1, tmp, %1;\n\t" - "}\n\t" - : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) - : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); - - } else if (platform::is_same::value) { - - asm volatile( - "{\n\t" - " .reg .f32 tmp;\n\t" - " shl.b32 tmp, %2, 16;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " and.b32 tmp, %2, 0xffff0000;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " shl.b32 tmp, %3, 16;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " and.b32 tmp, %3, 0xffff0000;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " shl.b32 tmp, %4, 16;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " and.b32 tmp, %4, 0xffff0000;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " shl.b32 tmp, %5, 16;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " and.b32 tmp, %5, 0xffff0000;\n\t" - " add.f32 %1, tmp, %1;\n\t" - "}\n\t" - : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) - : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); - - } else { - assert(0); - } - #endif - } - } - } - #else - assert(0); - #endif - } - - /// Transform the mma operands to the required types - CUTLASS_DEVICE - void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, - FragmentA const &A, FragmentB const &B) const { - - // - // Define conversions from source type to instruction type - // - FloatRoundStyle const kRoundA = - PreferredRoundingMode::kRound; - FloatRoundStyle const kRoundB = - PreferredRoundingMode::kRound; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_B = - reinterpret_cast const *>(&B); - Array * - ptr_dst_B = reinterpret_cast *>(&dst_B); - - dst_A = convert_A(A); - - ptr_dst_B[0] = convert_B(ptr_B[0]); - ptr_dst_B[1] = convert_B(ptr_B[1]); - - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_A = - reinterpret_cast const *>(&A); - Array * - ptr_dst_A = reinterpret_cast *>(&dst_A); - - dst_B = convert_B(B); - - ptr_dst_A[0] = convert_A(ptr_A[0]); - ptr_dst_A[1] = convert_A(ptr_A[1]); - #else - assert(0); - #endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h deleted file mode 100644 index 2d79dcf7005a3940e6960d5e9b5c7ad87ea4ed9f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h +++ /dev/null @@ -1,572 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Defines iterators used by warp-level loading scale and bias vectors. - Every scale/bias data only needs to be loaded once for every channel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" - -#include "cutlass/platform/platform.h" -#include "cutlass/fast_math.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of A elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Policy of the details of LDSM shape and iterations - typename Policy_, - /// Number of threads participating in one matrix operation - int Threads, - /// Number of partitions along K dimension - int PartitionsK_ = 1> -class ScaleBiasTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: PitchLinearShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: PitchLinearShape) - typename InstructionShape_, - /// Policy of the details of LDSM shape and iterations - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_> -class ScaleBiasTileIterator { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::PitchLinear; - - /// Shape of one matrix product operation (concept: GemmShape) - using InstructionShape = InstructionShape_; - - /// Number of participating threads - static int const kThreads = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - - /// Number of partitions along K dimension - static int const kElementsPerAccess = 128 / sizeof_bits::value; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - using Policy = Policy_; - - private: - - /// Pointer type used for accesses - using AccessType = Array; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = Array; - - private: - - /// Shared memory base pointers - not advanced - AccessType const *pointer_; - - /// Byte offset incremented as iterator advances - Index byte_offset_; - - /// Internal counter used to determine when to increment byte offset and when - /// to XOR it - int k_group_idx_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator() - : pointer_(nullptr), - byte_offset_(0), - k_group_idx_(0) {} - - /// Constructor from TensorRef - CUTLASS_DEVICE - ScaleBiasTileIterator(TensorRef const &ref_scale_bias, - int lane_id) - : byte_offset_(0), k_group_idx_(0) { - /// 16816 only - pointer_ = reinterpret_cast(ref_scale_bias.data()) + - ((lane_id >> 3) & 1) * Shape::kContiguous / kElementsPerAccess + - (lane_id >> 4); - } - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - ScaleBiasTileIterator &add_pointer_offset(LongIndex offset) { - byte_offset_ += offset * sizeof_bits::value / 8; - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - ScaleBiasTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; - int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; - - byte_offset_ += k_groups_delta * sizeof_bits::value * - kElementsPerAccess * Policy::LdsmShape::kContiguous / 8; - - // Multiply by 2 because scale and bias belonging to the same stage are next - // to each other in the shared memory. - pointer_ += (2 * whole_tiles * Shape::kContiguous / kElementsPerAccess); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - ScaleBiasTileIterator &operator++() { - byte_offset_ += Policy::LdsmShape::kContiguous * - sizeof_bits::value * kElementsPerAccess / 8; - - k_group_idx_++; - - if (k_group_idx_ == (Policy::kGroupsPerTile / kPartitionsK)) { - k_group_idx_ = 0; - byte_offset_ -= (Policy::kGroupsPerTile / kPartitionsK) * - Policy::LdsmShape::kContiguous * - sizeof_bits::value * kElementsPerAccess / 8; - add_tile_offset({Policy::kGroupsPerTile, 0}); - } - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator &operator--() { assert(0); } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - ScaleBiasTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - ScaleBiasTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - Array *fetch_ptr = - reinterpret_cast *>(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < 1; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { - int access_idx = c + s * Policy::LdsmIterations::kContiguous; - - AccessType const *source_ptr = - pointer_ + Policy::LdsmShape::kContiguous * c; - - char const *source_byte_ptr = - reinterpret_cast(source_ptr) + byte_offset + - byte_offset_; - - cutlass::arch::ldsm( - fetch_ptr[access_idx], source_byte_ptr); - } - } - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - load_with_byte_offset(frag, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - load_with_byte_offset(frag, tile_offset, 0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - Index pointer_offset = tile_offset.contiguous() * - InstructionShape::kContiguous / - kElementsPerAccess; - - byte_offset += sizeof_bits::value * pointer_offset / 8; - - load_with_byte_offset(frag, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - k_group_idx_ = k_group % (Policy::kGroupsPerTile / kPartitionsK); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -/// load from shared memory and therefore must be initialized with a TensorRef -/// to shared memory. -/// -/// Satisfies: -/// ReadableRandomAccessContiguousTileIteratorConcept -/// -template < - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Data type of elements - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - /// Policy of the details of LDSM shape and iterations - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_> -class ScaleBiasTileIterator { - public: - /// Shape of tile to load (concept: PitchLinearShape) - using Shape = Shape_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Number of participating threads - static int const kThreads = 32; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Internal structure of iterator - made public to enable introspection - using Policy = Policy_; - - /// Underlying tile iterator implementation - using Base = ScaleBiasTileIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, - layout::PitchLinearShape, - Policy, kThreads, PartitionsK_>; - - public: - // - // Derived quantities - // - - /// Fragment object holding a thread's part of a tile - using Fragment = typename Base::Fragment; - - private: - /// Underlying tile iterator - Base iterator_; - - public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator() {} - - /// Constructor from TensorRef - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator(TensorRef const &ref_scale_bias, int lane_id) - : iterator_({ref_scale_bias.data(), ref_scale_bias.stride()}, lane_id) {} - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator &add_pointer_offset(LongIndex offset) { - iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_DEVICE - ScaleBiasTileIterator &add_tile_offset_negative( - TensorCoord const &tile_offset) { - iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator &operator++() { - ++iterator_; - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_HOST_DEVICE - ScaleBiasTileIterator &operator--() { - --iterator_; - - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - ScaleBiasTileIterator &operator+=( - TensorCoord const &tile_offset) { - add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of - ///< the tensor - CUTLASS_DEVICE - ScaleBiasTileIterator &operator-=( - TensorCoord const &tile_offset) { - add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { iterator_.load(frag); } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index byte_offset) const { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - assert(0); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - iterator_.load_with_byte_offset( - frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - iterator_.set_kgroup_index(k_group); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h deleted file mode 100644 index 7e3af9bff42a8895c7fb1e55a873b74e2a7ba249..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h +++ /dev/null @@ -1,117 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level per-channel softmax before - matrix multiply-accumulate operations targeting Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SoftmaxScaleBiasTransform { - - using T = typename FragmentActivations::Element; - - static int const NumActivations = FragmentActivations::kElements; - static int const NumNormSum = FragmentNormSum::kElements; - static int const MmaElements = 2; - // One element has one scale and one bias - static int const MmaScaleBiasPair = 2; - // 16816 has 2 columns and 2 rows - static int const MmaCols = 2; - static int const MmaRows = 2; - - using MmaOperand = Array; - using NormSumOperand = Array<__half2, MmaScaleBiasPair>; - - CUTLASS_DEVICE - void transform(MmaOperand &activations, - NormSumOperand const &norm_sum) { - - __half2* packed_activations = reinterpret_cast<__half2*>(&activations); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < MmaElements / 2; ++i) { - __half2 out = ::h2exp(__hsub2(packed_activations[i], norm_sum[2*i])); - packed_activations[i] = __hmul2(out, norm_sum[2*i + 1]); - } - } - - CUTLASS_DEVICE - void operator()(FragmentActivations &activations, - FragmentNormSum const &norm_sum) { - MmaOperand *ptr_activations = reinterpret_cast(&activations); - NormSumOperand const *ptr_norm_sum = - reinterpret_cast(&norm_sum); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < (NumActivations / MmaElements); ++i) { - transform(ptr_activations[i], - ptr_norm_sum[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows]); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h deleted file mode 100644 index 0406db0ddff902995a92b5c11d4c5e5024334e4c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h +++ /dev/null @@ -1,250 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/array_planar_complex.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class TileIteratorPlanarComplex { -public: - - /// Underlying iterator over real-valued tiles - using TileIterator = TileIterator_; - - /// Underlying element type - using Element = typename TileIterator::Element; - - /// Underlying layout type - using Layout = typename TileIterator::Layout; - - /// TensorRef type for loading element from a tensor - using TensorRef = typename TileIterator::TensorRef; - - /// Index type - using Index = typename TensorRef::Index; - - /// Long Index type - using LongIndex = typename TensorRef::LongIndex; - - /// Coordinate for an element in the tensor - using TensorCoord = typename TensorRef::TensorCoord; - - /// Planar complex fragment - using Fragment = ArrayPlanarComplex; - -public: - - /// Underlying tile iterator - TileIterator tile_iterator_; - - /// Offset (in units of bytes) to the imaginary part of the planar complex matrix - LongIndex imaginary_offset_; - -public: - /// Default ctor constructs null iterator - CUTLASS_HOST_DEVICE - TileIteratorPlanarComplex(): imaginary_offset_(0) { } - - /// Constructor from TensorRef - CUTLASS_DEVICE - TileIteratorPlanarComplex( - TensorRef const &ref, - int lane_id, - LongIndex imaginary_offset - ): - tile_iterator_(ref, lane_id), - imaginary_offset_((imaginary_offset * sizeof_bits::value) / 8) { } - - - /// Adds a pointer offset to internal pointer(s) to advance through memory - CUTLASS_DEVICE - TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) { - - tile_iterator_.add_pointer_offset(offset); - - return *this; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_HOST_DEVICE - TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) { - - tile_iterator_.add_tile_offset(tile_offset); - - return *this; - } - - /// Advances the iterator along the advance dimension - CUTLASS_DEVICE - TileIteratorPlanarComplex & operator++() { - ++tile_iterator_; - return *this; - } - - // - // WIP - // - - /// Advances the iterator along the opposite of the advance dimension - CUTLASS_HOST_DEVICE - TileIteratorPlanarComplex & operator--() { - --tile_iterator_; - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) { - tile_iterator_.add_tile_offset(tile_offset); - return *this; - } - - ///< advances in units of whole tiles along the logical coordinate space of the tensor - CUTLASS_DEVICE - TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) { - tile_iterator_.add_tile_offset(-tile_offset); - return *this; - } - - /// Loads a fragment from memory at the location pointed to by the iterator. - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - tile_iterator_.load_with_byte_offset(frag.real, 0); - tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset in units of bytes - Index byte_offset) const { - - tile_iterator_.load_with_byte_offset(frag.real, byte_offset); - tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); - } - - /// Loads a fragment from memory with additional logical offset - CUTLASS_DEVICE - void load_with_pointer_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a linear offset - Index pointer_offset) const { - - Index byte_offset = (pointer_offset * sizeof_bits::value)/8; - - tile_iterator_.load_with_byte_offset(frag.real, byte_offset); - tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset) const { - - tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0); - tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index pointer_offset) const { - - Index byte_offset = (pointer_offset * sizeof_bits::value)/8; - - tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); - tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset + imaginary_offset_); - } - - /// Loads a fragment from memory with logical offset in units of whole tiles. - CUTLASS_DEVICE - void load_with_byte_offset( - /// fragment to load from the tensor - Fragment &frag, - /// loads a tile with a logical offset in units of whole tiles - TensorCoord const &tile_offset, - /// loads a tile with a logical offset AND a pointer offset - Index byte_offset) const { - - tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); - tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_); - } - - /// Notify the iterator which k-group it is currently pointing to. - /// - /// This does not advance the iterator. Rather, it overrides its internal - /// tracking with constant-valued k-group index to enable the compiler to - /// fold constants and achieve more efficient code. - /// - /// This is used by some nontrivial permuted layouts. - CUTLASS_DEVICE - void set_kgroup_index(int k_group) { - tile_iterator_.set_kgroup_index(k_group); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm_coord.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm_coord.h deleted file mode 100644 index dd826de23c463d021d5c0abb50867faebbdc9b47..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm_coord.h +++ /dev/null @@ -1,394 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/coord.h" - -namespace cutlass { -namespace gemm { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Shape of a matrix multiply-add operation -template < - /// Rows of matrix product - int M = 1, - /// Columns of matrix product - int N = 1, - /// Inner dimension of matrix product - int K = 1 -> -struct GemmShape { - static int const kM = M; - static int const kN = N; - static int const kK = K; - - static int const kMN = M * N; - static int const kMK = M * K; - static int const kKN = N * K; - static int const kMNK = M * N * K; - - static int const kCount = kMNK; - - // - // Static member functions - // - - /// Returns a Coord object - CUTLASS_HOST_DEVICE - static Coord<3> toCoord() { - return make_Coord(kM, kN, kK); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Type alias of the transpose of a GemmShape -template < - /// concept: GemmShape - typename Shape -> -using GemmShapeTranspose = GemmShape; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// GemmCoord is a structure derived from Coord<3> that specifies a location within the -/// coordinate space of a GEMM problem. -struct GemmCoord : public Coord<3, int> { - - /// Integer-valued index - typedef int Index; - - /// Base type is a Coord of rank=3 - typedef Coord<3, Index> Base; - - /// GEMM M dimension - rows of the output C matrix - static int const kM = 0; - - /// GEMM N dimension - columns of the output C matrix - static int const kN = 1; - - /// GEMM K dimension - inner dimension of the GEMM problem - static int const kK = 2; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - GemmCoord() { } - - /// Constructs from Coord<3> and a batch - CUTLASS_HOST_DEVICE - GemmCoord(Coord<3, Index> const& coord): Base(make_Coord(coord[0], coord[1], coord[2])) { } - - /// Helper to construct from a K, N, M, batch variables - CUTLASS_HOST_DEVICE - GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { } - - /// Returns the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index const& m() const { return this->at(kM); } - - /// Returns reference to the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index & m() { return this->at(kM); } - - /// Returns the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index const& n() const { return this->at(kN); } - - /// Returns reference to the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index & n() { return this->at(kN); } - - /// Returns the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index const& k() const { return this->at(kK); } - - /// Returns reference to the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index & k() { return this->at(kK); } - - /// Obtains a Coord<3> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<3> mnk() const { - return make_Coord(m(), n(), k()); - } - - /// Obtains a Coord<3> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<3> knm() const { - return make_Coord(k(), n(), m()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> nm() const { - return make_Coord(n(), m()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> mn() const { - return make_Coord(m(), n()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> mk() const { - return make_Coord(m(), k()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> km() const { - return make_Coord(k(), m()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> nk() const { - return make_Coord(n(), k()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> kn() const { - return make_Coord(k(), n()); - } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - GemmCoord operator+(Base const& b) const { - return GemmCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - GemmCoord operator-(Base const& b) const { - return GemmCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - GemmCoord operator*(Base const& b) const { - return GemmCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - GemmCoord operator/(Base const& b) const { - return GemmCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - GemmCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - GemmCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - GemmCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - GemmCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// BatchedGemmCoord is a structure derived from Coord<4> that specifies a location within the -/// coordinate space of a batched GEMM problem. -struct BatchedGemmCoord : public Coord<4, int> { - - /// Integer-valued index - typedef int Index; - - /// Base type is a Coord of rank=4 - typedef Coord<4, Index> Base; - - /// GEMM M dimension - rows of the output C matrix - static int const kM = 0; - - /// GEMM N dimension - columns of the output C matrix - static int const kN = 1; - - /// GEMM K dimension - inner dimension of the GEMM problem - static int const kK = 2; - - /// GEMM Batch dimension - inner dimension of the GEMM problem - static int const kBatch = 3; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - BatchedGemmCoord() { } - - /// Constructs from Coord<4> - CUTLASS_HOST_DEVICE - BatchedGemmCoord(Base const& coord): Base(coord) { } - - /// Helper to construct from a K, N, M, and batch variables - CUTLASS_HOST_DEVICE - BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { } - - /// Returns the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index const& m() const { return this->at(kM); } - - /// Returns reference to the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index & m() { return this->at(kM); } - - /// Returns the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index const& n() const { return this->at(kN); } - - /// Returns reference to the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index & n() { return this->at(kN); } - - /// Returns the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index const& k() const { return this->at(kK); } - - /// Returns reference to the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index & k() { return this->at(kK); } - - /// Returns the GEMM batch coordinate - CUTLASS_HOST_DEVICE - Index const& batch() const { return this->at(kBatch); } - - /// Returns reference to the GEMM batch coordinate - CUTLASS_HOST_DEVICE - Index & batch() { return this->at(kBatch); } - - /// Obtains a GemmCoord from BatchedGemmCoord - CUTLASS_HOST_DEVICE - GemmCoord mnk() const { - return GemmCoord(m(), n(), k()); - } - - /// Obtains a Coord<4> from BatchedGemmCoord - CUTLASS_HOST_DEVICE - Coord<4> mnkb() const { - return make_Coord(m(), n(), k(), batch()); - } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator+(Base const& b) const { - return BatchedGemmCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator-(Base const& b) const { - return BatchedGemmCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator*(Base const& b) const { - return BatchedGemmCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator/(Base const& b) const { - return BatchedGemmCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm_coord.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm_coord.hpp deleted file mode 100644 index a22b8031d186f25e58cd96df6c75606454d50d0f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm_coord.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Utilities to convert a CuTe tuple to a GemmCoord or BatchedGemmCoord -*/ - -#pragma once - -#include "cute/layout.hpp" -#include "cutlass/gemm_coord.h" - -namespace cutlass { -namespace gemm { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_HOST_DEVICE -auto -to_gemm_coord(Tuple tuple) { - static_assert(cute::rank(tuple) <= 4, "Can only convert tuples of rank <= 4."); - - if constexpr (cute::rank(tuple) <= 3) { - auto tuple_mnk = cute::append<3>(tuple, cute::Int<0>{}); - return GemmCoord(cute::size<0>(tuple_mnk), cute::size<1>(tuple_mnk), cute::size<2>(tuple_mnk)); - } - else { - return BatchedGemmCoord(cute::size<0>(tuple), cute::size<1>(tuple), cute::size<2>(tuple), cute::size<3>(tuple)); - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/half.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/half.h deleted file mode 100644 index 118a80d7045dddd4239fc7f0756dc445fa9a2895..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/half.h +++ /dev/null @@ -1,930 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Defines a class for using IEEE half-precision floating-point types in host or - device code. -*/ - -#pragma once - -#ifndef CUTLASS_ENABLE_F16C -#define CUTLASS_ENABLE_F16C 0 -#endif - -#if defined(__CUDACC_RTC__) - -#include "cutlass/floating_point_nvrtc.h" - -// F16C extensions are not meaningful when compiling for NVRTC which only accommodates device code. -#undef CUTLASS_ENABLE_F16C -#define CUTLASS_ENABLE_F16C 0 - -#else -// -// Standard Library headers belong here to avoid conflicts with NVRTC. -// -#include -#include -#include -#include -#endif - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/platform/platform.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Optionally target F16C extensions to accelerate half-precision conversion. -#if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C) -#if defined(_MSC_VER) - -#include - -#if defined(__i386__) || defined(__x86_64__) -#include -#endif - -#define F16C_ROUND_NEAREST 0 - -#if !defined(__CUDA_ARCH__) -extern __inline float _cvtsh_ss (unsigned short __S) { - __m128i packed; - std::memcpy(&packed, &__S, sizeof(__S)); - - __m128 result = _mm_cvtph_ps(packed); - - float flt; - std::memcpy(&flt, &result, sizeof(flt)); - - return flt; -} - -__inline unsigned short _cvtss_sh (float __F, const int) { - __m128 packed; - std::memcpy(&packed, &__F, sizeof(__F)); - - __m128i result = _mm_cvtps_ph(packed, F16C_ROUND_NEAREST); - - unsigned short u; - std::memcpy(&u, &result, sizeof(u)); - - return u; -} -#endif - -#else - -// Linux -#include - -#if defined(__i386__) || defined(__x86_64__) -#include -#endif - -#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) - -#endif // _MSC_VER - -class CpuId { - - bool f16c_enabled; - - CpuId() { - #if defined(__i386__) || defined(__x86_64__) - #if defined(_MSC_VER) - int exx[4]; - - __cpuid (exx, 1); - f16c_enabled = exx[2] & 0x20000000; - - #else - // GCC / Clang - int eax, ebx, ecx, edx; - - __cpuid (1 , eax, ebx, ecx, edx); - f16c_enabled = ecx & 0x20000000; - #endif - #else - // Arm / PowerPC etc. - f16c_enabled = false; - #endif - } - -public: - - bool is_f16c_supported() const { - return f16c_enabled; - } - - static const CpuId& instance() { - static CpuId cpu; - return cpu; - } -}; -#endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// IEEE half-precision floating-point type -struct alignas(2) half_t { - - // - // Data members - // - - /// Storage type - uint16_t storage; - - // - // Static conversion operators - // - - /// Constructs from an unsigned short - CUTLASS_HOST_DEVICE - static half_t bitcast(uint16_t x) { - half_t h; - h.storage = x; - return h; - } - - /// FP32 -> FP16 conversion - rounds to nearest even - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) - // Avoid inlining in device code if no hardware support - __device__ __noinline__ - #else - CUTLASS_HOST_DEVICE - #endif - static half_t convert(float const& flt) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__float2half_rn(flt)); - #else - - #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - if( CpuId::instance().is_f16c_supported() ) { - unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); - return bitcast(u); - } - #endif - - // software implementation rounds toward nearest even - unsigned s; - - #if defined(__CUDA_ARCH__) - s = reinterpret_cast(flt); - #else - std::memcpy(&s, &flt, sizeof(s)); - #endif - - uint16_t sign = uint16_t((s >> 16) & 0x8000); - int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); - int mantissa = s & 0x7fffff; - uint16_t u = 0; - - if ((s & 0x7fffffff) == 0) { - // sign-preserving zero - return bitcast(sign); - } - - if (exp > 15) { - if (exp == 128 && mantissa) { - // not a number - u = 0x7fff; - } else { - // overflow to infinity - u = sign | 0x7c00; - } - return bitcast(u); - } - - int sticky_bit = 0; - - if (exp >= -14) { - // normal fp32 to normal fp16 - exp = uint16_t(exp + uint16_t(15)); - u = uint16_t(((exp & 0x1f) << 10)); - u = uint16_t(u | (mantissa >> 13)); - } else { - // normal single-precision to subnormal half_t-precision representation - int rshift = (-14 - exp); - if (rshift < 32) { - mantissa |= (1 << 23); - - sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); - - mantissa = (mantissa >> rshift); - u = (uint16_t(mantissa >> 13) & 0x3ff); - } else { - mantissa = 0; - u = 0; - } - } - - // round to nearest even - int round_bit = ((mantissa >> 12) & 1); - sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0); - - if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { - u = uint16_t(u + 1); - } - - u |= sign; - - return bitcast(u); - #endif - } - - /// FP32 -> FP16 conversion - rounds to nearest even - CUTLASS_HOST_DEVICE - static half_t convert(int const& n) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__int2half_rn(n)); - #else - return convert(float(n)); - #endif - } - - /// FP32 -> FP16 conversion - rounds to nearest even - CUTLASS_HOST_DEVICE - static half_t convert(unsigned const& n) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__uint2half_rn(n)); - #else - return convert(float(n)); - #endif - } - - /// Converts a half-precision value stored as a uint16_t to a float - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) - // Avoid inlining in device code if no hardware support - __device__ __noinline__ - #else - CUTLASS_HOST_DEVICE - #endif - static float convert(half_t const& x) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __half2float(x.to_half()); - #else - - #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - if( CpuId::instance().is_f16c_supported() ) { - unsigned short u = x.storage; - return _cvtsh_ss(u); - } - #endif - - uint16_t const &h = x.storage; - uint32_t sign = ((h >> 15) & 1); - uint32_t exp = ((h >> 10) & 0x1f); - uint32_t mantissa = (h & 0x3ff); - unsigned f = 0; - - if (exp > 0 && exp < 31) { - // normal - exp += 112; - f = (sign << 31) | (exp << 23) | (mantissa << 13); - } else if (exp == 0) { - if (mantissa) { - // subnormal - exp += 113; - while ((mantissa & (1 << 10)) == 0) { - mantissa <<= 1; - exp--; - } - mantissa &= 0x3ff; - f = (sign << 31) | (exp << 23) | (mantissa << 13); - } else { - // sign-preserving zero - f = (sign << 31); - } - } else if (exp == 31) { - if (mantissa) { - f = 0x7fffffff; // not a number - } else { - f = (0xff << 23) | (sign << 31); // inf - } - } - #if defined(__CUDA_ARCH__) - return reinterpret_cast(f); - #else - float flt; - std::memcpy(&flt, &f, sizeof(flt)); - return flt; - #endif - #endif - } - - // - // Methods - // - - /// Default constructor - half_t() = default; - - /// Reinterpret cast from CUDA's half type - CUTLASS_HOST_DEVICE - explicit half_t(half const & x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - __half_raw raw(x); - std::memcpy(&storage, &raw.x, sizeof(storage)); - #endif - } - - /// Floating point conversion - CUTLASS_HOST_DEVICE - explicit half_t(float x) { - storage = convert(x).storage; - } - - /// Floating point conversion - CUTLASS_HOST_DEVICE - explicit half_t(double x): half_t(float(x)) { - - } - - /// float_e4m3_t conversion - CUTLASS_HOST_DEVICE - explicit half_t(float_e4m3_t x): half_t(float(x)) { - - } - - /// float_e5m2_t conversion - CUTLASS_HOST_DEVICE - explicit half_t(float_e5m2_t x): half_t(float(x)) { - - } - - /// Integer conversion - round to nearest even - CUTLASS_HOST_DEVICE - explicit half_t(int x) { - storage = convert(x).storage; - } - - /// Integer conversion - round toward zero - CUTLASS_HOST_DEVICE - explicit half_t(unsigned x) { - storage = convert(x).storage; - } - - /// Assignment - CUTLASS_HOST_DEVICE - half_t & operator=(half const &x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - __half_raw raw(x); - std::memcpy(&storage, &raw.x, sizeof(storage)); - #endif - return *this; - } - - /// Converts to float - CUTLASS_HOST_DEVICE - operator float() const { - return convert(*this); - } - - /// Converts to float - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(convert(*this)); - } - - /// Converts to float - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(convert(*this)); - } - - /// Casts to bool - CUTLASS_HOST_DEVICE - explicit operator bool() const { - return (convert(*this) != 0.0f); - } - - /// Bitcasts to CUDA's half type - CUTLASS_HOST_DEVICE - half to_half() const { - #if defined(__CUDA_ARCH__) - return reinterpret_cast(storage); - #else - __half_raw raw; - std::memcpy(&raw.x, &storage, sizeof(raw.x)); - return half(raw); - #endif - } - - /// Accesses raw internal state - CUTLASS_HOST_DEVICE - uint16_t& raw() { - return storage; - } - - /// Accesses raw internal state - CUTLASS_HOST_DEVICE - uint16_t raw() const { - return storage; - } - - /// Returns the sign bit - CUTLASS_HOST_DEVICE - bool signbit() const { - return ((storage & 0x8000) != 0); - } - - /// Returns the biased exponent - CUTLASS_HOST_DEVICE - int exponent_biased() const { - return int((storage >> 10) & 0x1f); - } - - /// Returns the unbiased exponent - CUTLASS_HOST_DEVICE - int exponent() const { - return exponent_biased() - 15; - } - - /// Returns the mantissa - CUTLASS_HOST_DEVICE - int mantissa() const { - return int(storage & 0x3ff); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -bool signbit(cutlass::half_t const& h) { - return ((h.raw() & 0x8000) != 0); -} - -CUTLASS_HOST_DEVICE -cutlass::half_t abs(cutlass::half_t const& h) { - return cutlass::half_t::bitcast(h.raw() & 0x7fff); -} - -CUTLASS_HOST_DEVICE -bool isnan(cutlass::half_t const& h) { - return (h.exponent_biased() == 0x1f) && h.mantissa(); -} - -CUTLASS_HOST_DEVICE -bool isfinite(cutlass::half_t const& h) { - return (h.exponent_biased() != 0x1f); -} - -CUTLASS_HOST_DEVICE -cutlass::half_t nanh(const char*) { - // NVIDIA canonical NaN - return cutlass::half_t::bitcast(0x7fff); -} - -CUTLASS_HOST_DEVICE -bool isinf(cutlass::half_t const& h) { - return (h.exponent_biased() == 0x1f) && !h.mantissa(); -} - -CUTLASS_HOST_DEVICE -bool isnormal(cutlass::half_t const& h) { - return h.exponent_biased() && h.exponent_biased() != 0x1f; -} - -CUTLASS_HOST_DEVICE -int fpclassify(cutlass::half_t const& h) { - int exp = h.exponent_biased(); - int mantissa = h.mantissa(); - if (exp == 0x1f) { - if (mantissa) { - return FP_NAN; - } - else { - return FP_INFINITE; - } - } - else if (!exp) { - if (mantissa) { - return FP_SUBNORMAL; - } - else { - return FP_ZERO; - } - } - return FP_NORMAL; -} - -CUTLASS_HOST_DEVICE -cutlass::half_t sqrt(cutlass::half_t const& h) { -#if defined(__CUDACC_RTC__) - return cutlass::half_t(sqrtf(float(h))); -#else - return cutlass::half_t(std::sqrt(float(h))); -#endif -} - -CUTLASS_HOST_DEVICE -half_t copysign(half_t const& a, half_t const& b) { - - uint16_t a_mag = (a.raw() & 0x7fff); - uint16_t b_sign = (b.raw() & 0x8000); - uint16_t result = (a_mag | b_sign); - - return half_t::bitcast(result); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Standard Library operations and definitions -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#if !defined(__CUDACC_RTC__) -namespace std { - -/// Numeric limits -template <> -struct numeric_limits { - static bool const is_specialized = true; - static bool const is_signed = true; - static bool const is_integer = false; - static bool const is_exact = false; - static bool const has_infinity = true; - static bool const has_quiet_NaN = true; - static bool const has_signaling_NaN = false; - static std::float_denorm_style const has_denorm = std::denorm_present; - static bool const has_denorm_loss = true; - static std::float_round_style const round_style = std::round_to_nearest; - static bool const is_iec559 = true; - static bool const is_bounded = true; - static bool const is_modulo = false; - static int const digits = 10; - - /// Least positive value - CUTLASS_HOST_DEVICE - static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } - - /// Minimum finite value - CUTLASS_HOST_DEVICE - static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } - - /// Maximum finite value - CUTLASS_HOST_DEVICE - static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } - - /// Returns maximum rounding error - CUTLASS_HOST_DEVICE - static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } - - /// Returns positive infinity value - CUTLASS_HOST_DEVICE - static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } - - /// Returns quiet NaN value - CUTLASS_HOST_DEVICE - static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } - - /// Returns signaling NaN value - CUTLASS_HOST_DEVICE - static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } - - /// Returns smallest positive subnormal value - CUTLASS_HOST_DEVICE - static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } -}; -} // namespace std -#endif - -namespace cutlass { -namespace platform { - -/// Forward Declaration -template -struct numeric_limits; - -/// Numeric limits -template <> -struct numeric_limits { - static bool const is_specialized = true; - static bool const is_signed = true; - static bool const is_integer = false; - static bool const is_exact = false; - static bool const has_infinity = true; - static bool const has_quiet_NaN = true; - static bool const has_signaling_NaN = false; -#if !defined(__CUDACC_RTC__) - static std::float_denorm_style const has_denorm = std::denorm_present; -#endif - static bool const has_denorm_loss = true; -#if !defined(__CUDACC_RTC__) - static std::float_round_style const round_style = std::round_to_nearest; -#endif - static bool const is_iec559 = true; - static bool const is_bounded = true; - static bool const is_modulo = false; - static int const digits = 10; - - /// Least positive value - CUTLASS_HOST_DEVICE - static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } - - /// Minimum finite value - CUTLASS_HOST_DEVICE - static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } - - /// Maximum finite value - CUTLASS_HOST_DEVICE - static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } - - /// Returns smallest finite value - CUTLASS_HOST_DEVICE - static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } - - /// Returns maximum rounding error - CUTLASS_HOST_DEVICE - static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } - - /// Returns positive infinity value - CUTLASS_HOST_DEVICE - static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } - - /// Returns quiet NaN value - CUTLASS_HOST_DEVICE - static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } - - /// Returns signaling NaN value - CUTLASS_HOST_DEVICE - static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } - - /// Returns smallest positive subnormal value - CUTLASS_HOST_DEVICE - static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } -}; -} // namespace platform -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Arithmetic operators -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -bool operator==(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __heq(lhs.to_half(), rhs.to_half()); -#else - return float(lhs) == float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator!=(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __hne(lhs.to_half(), rhs.to_half()); -#else - return float(lhs) != float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator<(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __hlt(lhs.to_half(), rhs.to_half()); -#else - return float(lhs) < float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator<=(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __hle(lhs.to_half(), rhs.to_half()); -#else - return float(lhs) <= float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator>(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __hgt(lhs.to_half(), rhs.to_half()); -#else - return float(lhs) > float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -bool operator>=(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return __hge(lhs.to_half(), rhs.to_half()); -#else - return float(lhs) >= float(rhs); -#endif -} - -CUTLASS_HOST_DEVICE -half_t operator+(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__hadd(lhs.to_half(), rhs.to_half())); -#else - return half_t(float(lhs) + float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -half_t operator-(half_t const& lhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__hneg(lhs.to_half())); -#else - return half_t(-float(lhs)); -#endif -} - -CUTLASS_HOST_DEVICE -half_t operator-(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__hsub(lhs.to_half(), rhs.to_half())); -#else - return half_t(float(lhs) - float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -half_t operator*(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__hmul(lhs.to_half(), rhs.to_half())); -#else - return half_t(float(lhs) * float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -half_t operator/(half_t const& lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return half_t(__hdiv(lhs.to_half(), rhs.to_half())); -#else - return half_t(float(lhs) / float(rhs)); -#endif -} - -CUTLASS_HOST_DEVICE -half_t& operator+=(half_t & lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hadd(lhs.to_half(), rhs.to_half())); -#else - lhs = half_t(float(lhs) + float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -half_t& operator-=(half_t & lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hsub(lhs.to_half(), rhs.to_half())); -#else - lhs = half_t(float(lhs) - float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -half_t& operator*=(half_t & lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hmul(lhs.to_half(), rhs.to_half())); -#else - lhs = half_t(float(lhs) * float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -half_t& operator/=(half_t & lhs, half_t const& rhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hdiv(lhs.to_half(), rhs.to_half())); -#else - lhs = half_t(float(lhs) / float(rhs)); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -half_t& operator++(half_t & lhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hadd(lhs.to_half(), half_t(1.0f).to_half())); -#else - float tmp(lhs); - ++tmp; - lhs = half_t(tmp); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -half_t& operator--(half_t & lhs) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hsub(lhs.to_half(), half_t(1.0f).to_half())); -#else - float tmp(lhs); - --tmp; - lhs = half_t(tmp); -#endif - return lhs; -} - -CUTLASS_HOST_DEVICE -half_t operator++(half_t & lhs, int) { - half_t ret(lhs); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hadd(lhs.to_half(), half_t(1.0f).to_half())); -#else - float tmp(lhs); - tmp++; - lhs = half_t(tmp); -#endif - return ret; -} - -CUTLASS_HOST_DEVICE -half_t operator--(half_t & lhs, int) { - half_t ret(lhs); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - lhs = half_t(__hsub(lhs.to_half(), half_t(1.0f).to_half())); -#else - float tmp(lhs); - tmp--; - lhs = half_t(tmp); -#endif - return ret; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// User-defined literals -// - -CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(long double x) { - return cutlass::half_t(float(x)); -} - -CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(unsigned long long int x) { - return cutlass::half_t(int(x)); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/integer_subbyte.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/integer_subbyte.h deleted file mode 100644 index 43047eaeec355b8c13ce034ffa7d508f083e823b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/integer_subbyte.h +++ /dev/null @@ -1,301 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Defines a class for using integer types smaller than one byte in host or - device code. -*/ - -#pragma once -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#endif - -#include "cutlass/numeric_size.h" -#include "cutlass/platform/platform.h" - -namespace cutlass { - -template -struct integer_subbyte { - using Storage = uint8_t; - - static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte"); - - // "External type"; the integer type for which - // integer_subbyte has a conversion-to operator - using xint_t = typename cutlass::platform::conditional::type; - - // Bitmask for truncation from larger integers - static constexpr Storage bits_mask_ = Storage(Storage(-1) >> (8 - Bits)); - // Bitmask for the sign bit - static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1)); - - // Where the bits are stored - Storage storage; - - // Default construction does NOT zero-initialize - integer_subbyte() = default; - - // Implicit conversion is DEPRECATED. - // Please use one of the two explicit constructors below. - template> - > -#if !defined(CUTLASS_EXTRA_WARNINGS) - [[deprecated("Implicit conversion is deprecated; please use explicit construction instead")]] -#endif - CUTLASS_HOST_DEVICE - integer_subbyte(T value) - : integer_subbyte(static_cast(value)) {} - - CUTLASS_HOST_DEVICE - integer_subbyte(float value) - : integer_subbyte(static_cast(value)) {} - - // CUTLASS code commonly converts both signed and unsigned integers - // into integer_subbyte, so the class provides both explicit - // conversions. - - // Precondition: If the external type is unsigned int, then value - // fits in unsigned int (is nonnegative). - CUTLASS_HOST_DEVICE explicit - integer_subbyte(int value) - : storage(reinterpret_cast(value) & bits_mask_) - { - if constexpr (Signed) { - [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); - [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; - assert(value >= lower_bound); - assert(value <= upper_bound); - } - else { - [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; - assert(value >= 0); - assert(value < static_cast(upper_bound)); - } - } - - // Precondition: If the external type is (signed) int, then value - // fits in int. - CUTLASS_HOST_DEVICE explicit - integer_subbyte(unsigned value) - : storage(reinterpret_cast(value) & bits_mask_) - { - if constexpr (Signed) { - [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); - [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; - assert(value >= lower_bound); - assert(value <= upper_bound); - } - else { - [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; - assert(value < upper_bound); - } - } - - CUTLASS_HOST_DEVICE explicit - integer_subbyte(uint8_t value) - : integer_subbyte(static_cast(value)) {} - - // Convert to the "external" integer type (int or unsigned) - CUTLASS_HOST_DEVICE - operator xint_t() const { - if (sign_mask_ & storage) { // Sign extend - return xint_t(storage) | ~xint_t(bits_mask_); - } else { - return xint_t(storage); - } - } - - CUTLASS_HOST_DEVICE - bool operator==(integer_subbyte const& rhs) const { - return storage == rhs.storage; - } - - CUTLASS_HOST_DEVICE - bool operator!=(integer_subbyte const& rhs) const { - return storage != rhs.storage; - } - - CUTLASS_HOST_DEVICE - bool operator<(integer_subbyte const& rhs) const { - if ((sign_mask_ & storage) == (sign_mask_ & rhs.storage)) { - // If both *this and rhs have the same sign, compare storage directly. - return storage < rhs.storage; - } - else { - // If *this and rhs don't have the same sign, - // then return whether *this is negative. - return sign_mask_ & storage; - } - } - - CUTLASS_HOST_DEVICE - bool operator<=(integer_subbyte const& rhs) const { - if ((sign_mask_ & storage) == (sign_mask_ & rhs.storage)) { - // If both *this and rhs have the same sign, compare storage directly. - return storage <= rhs.storage; - } - else { - // If *this and rhs don't have the same sign, - // then return whether *this is negative. - return sign_mask_ & storage; - } - } - - CUTLASS_HOST_DEVICE - bool operator>=(integer_subbyte const& rhs) const { - return !(*this < rhs); - } - - CUTLASS_HOST_DEVICE - bool operator>(integer_subbyte const& rhs) const { - return !(*this <= rhs); - } - - CUTLASS_HOST_DEVICE friend integer_subbyte - conj(integer_subbyte const& x) { - return x; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// 1-bit binary type -using bin1_t = bool; - -/// 1-bit Unsigned integer type -using uint1b_t = integer_subbyte<1, false>; - -/// 2-bit Integer type -using int2b_t = integer_subbyte<2, true>; - -/// 2-bit Unsigned integer type -using uint2b_t = integer_subbyte<2, false>; - -/// 3-bit Integer type -using int3b_t = integer_subbyte<3, true>; - -/// 3-bit Unsigned integer type -using uint3b_t = integer_subbyte<3, false>; - -/// 4-bit Integer type -using int4b_t = integer_subbyte<4, true>; - -/// 4-bit Unsigned integer type -using uint4b_t = integer_subbyte<4, false>; - -/// 6-bit integer type -using int6b_t = integer_subbyte<6, true>; - -/// 6-bit unsigned integer type -using uint6b_t = integer_subbyte<6, false>; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct sizeof_bits> { - static constexpr int value = Bits; -}; - -/// Defines the size of an element in bits - specialized for bin1_t -template <> -struct sizeof_bits { - static constexpr int value = 1; -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace platform { - -/// Forward Declaration -template -struct numeric_limits; - -// Specialization for signed integer_subbyte -template -struct numeric_limits> { -private: - using value_type = cutlass::integer_subbyte; - -public: - CUTLASS_HOST_DEVICE static value_type lowest() noexcept { - return value_type{ - -(1 << (NumBits - 1)) - }; - } - - CUTLASS_HOST_DEVICE static value_type max() noexcept { - return value_type{ - (1 << (NumBits - 1)) - 1 - }; - } - - CUTLASS_HOST_DEVICE static value_type const min() noexcept { - return lowest(); - } - - static constexpr bool is_integer = true; - static constexpr bool is_signed = true; - static constexpr bool has_infinity = false; -}; - -// Specialization for unsigned integer_subbyte -template -struct numeric_limits> { -private: - using value_type = cutlass::integer_subbyte; - -public: - CUTLASS_HOST_DEVICE static value_type lowest() noexcept { - return value_type{0u}; - } - - CUTLASS_HOST_DEVICE static value_type max() noexcept { - return value_type{ - (1u << NumBits) - 1u - }; - } - - CUTLASS_HOST_DEVICE static value_type const min() noexcept { - return lowest(); - } - - static constexpr bool is_integer = true; - static constexpr bool is_signed = false; -}; - -} // namespace platform -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_hardware_info.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_hardware_info.h deleted file mode 100644 index 5d7c685f6e830b2cf90611f84ff5f65afc058c17..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_hardware_info.h +++ /dev/null @@ -1,137 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/device_kernel.h" -#if !defined(__CUDACC_RTC__) -#include "cuda_runtime.h" -#include "cutlass/cluster_launch.hpp" -#include "cutlass/trace.h" -#endif -#include - -namespace cutlass { - -struct KernelHardwareInfo { - // - // Data members - // - - // Hardware properties - int device_id = 0; - int sm_count = 0; - - // Kernel properties - int max_active_clusters = 0; // Maximum number of clusters that could co-exist on the target device. - dim3 cluster_shape = {0,0,0}; - dim3 cluster_shape_fallback = {0,0,0}; - - // - // Methods - // - -#if !defined(__CUDACC_RTC__) - static inline int - query_device_multiprocessor_count(int device_id = 0) { - cudaError_t result = cudaGetDevice(&device_id); - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST( - " cudaGetDevice() returned error " - << cudaGetErrorString(result)); - return 0; - } - int multiprocessor_count; - result = cudaDeviceGetAttribute(&multiprocessor_count, - cudaDevAttrMultiProcessorCount, device_id); - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST( - " cudaDeviceGetAttribute() returned error " - << cudaGetErrorString(result)); - return 0; - } - return multiprocessor_count; - } - - // Query maximum number of active clusters that could co-exist on the target device - // based on kernel properties such as cluster dims and threadblock dims - static inline int - query_device_max_active_clusters( - dim3 cluster_dims, - uint32_t threads_per_block, - void const* kernel_ptr) { - int max_active_clusters = 0; -#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) - ClusterLauncher::LaunchConfig cluster_launch_config = ClusterLauncher::make_cluster_launch_config( - cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1}); - // Given the kernel function and launch configuration, return the maximum number of clusters that could co-exist on the target device. - cudaError_t result = cudaOccupancyMaxActiveClusters(&max_active_clusters, kernel_ptr, &cluster_launch_config.launch_config); - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST( - " cudaGetDevice() returned error " - << cudaGetErrorString(result)); - return 0; - } - CUTLASS_TRACE_HOST("cudaOccupancyMaxActiveClusters: maximum number of clusters that could co-exist on the target device = " - << max_active_clusters << "\n"); - return max_active_clusters; -#else - CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster occupancy query."); - return max_active_clusters; -#endif - } - - // Simpler version of the above query function that fetches relevant information from the Kernel - template - static inline int - query_device_max_active_clusters() { - dim3 cluster_dims(cute::size<0>(typename Kernel::ClusterShape{}), - cute::size<1>(typename Kernel::ClusterShape{}), - cute::size<2>(typename Kernel::ClusterShape{})); - uint32_t threads_per_block = Kernel::MaxThreadsPerBlock; - void const* kernel_ptr = (void*)(device_kernel); - return query_device_max_active_clusters(cluster_dims, threads_per_block, kernel_ptr); - } - - template - static inline KernelHardwareInfo - make_kernel_hardware_info(int const device_id = 0, int sm_count = 0, int max_active_clusters = 0) { - if (sm_count == 0) { - sm_count = query_device_multiprocessor_count(device_id); - } - if (max_active_clusters == 0) { - max_active_clusters = query_device_max_active_clusters(); - } - return {device_id, sm_count, max_active_clusters}; - } -#endif -}; - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_hardware_info.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_hardware_info.hpp deleted file mode 100644 index e1758eac060aae26ccd8dd36fb06db71ff354bb6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_hardware_info.hpp +++ /dev/null @@ -1,35 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -// Simply import .h version of header so as to avoid breaking any existing CUTLASS builds -// after .hpp was changed to .h -#include "cutlass/kernel_hardware_info.h" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_launch.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_launch.h deleted file mode 100644 index e92e6c13f51315316051dabadc635de25bbbae90..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/kernel_launch.h +++ /dev/null @@ -1,142 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines structures and helpers to launch CUDA kernels within CUTLASS. -*/ - -#pragma once - -#include -#include "cutlass/cutlass.h" -#include "cutlass/trace.h" -#include "cutlass/device_kernel.h" // cutlass::device_kernel - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure containing the basic launch configuration of a CUDA kernel. -struct KernelLaunchConfiguration { - - /// CUDA grid dimensions - dim3 grid; - - /// CUDA threablock dimensions - dim3 block; - - /// Bytes of dynamically allocated SMEM in addition to static SMEM - size_t dynamic_smem; - - // - // Methods - // - - /// Constructs a KernellaunchConfiguration object - CUTLASS_HOST_DEVICE - KernelLaunchConfiguration( - dim3 _grid = dim3(1,1,1), - dim3 _block = dim3(1,1,1), - size_t _dynamic_smem = 0 - ): - grid(_grid), - block(_block), - dynamic_smem(_dynamic_smem) { } -}; - - -template -Status kernel_launch( - dim3 const grid_dims, - dim3 const block_dims, - size_t const smem_size, - cudaStream_t cuda_stream, - const Params &kernel_params, - bool launch_with_pdl) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::kernel_launch"); -#endif - - if (not launch_with_pdl) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::kernel_launch: No PDL"); -#endif - device_kernel<<>>(kernel_params); - } - else { -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) - if constexpr (GemmKernel::ArchTag::kMinComputeCapability < 90) { - CUTLASS_TRACE_HOST(" Programmatic dependent launch (PDL) is only supported for SM90."); - return Status::kInvalid; - } - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - - config.gridDim = grid_dims; - config.blockDim = block_dims; - config.dynamicSmemBytes = smem_size; - config.stream = cuda_stream; - - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = 1; - config.numAttrs = 1; - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::kernel_launch: Calling cudaLaunchKernelEx"); -#endif - cudaError_t launch_result = cudaLaunchKernelEx(&config, &device_kernel, kernel_params); - if (cudaSuccess != launch_result) { - CUTLASS_TRACE_HOST("cutlass::kernel_launch: cudaLaunchKernelEx failed with error: " << cudaGetErrorString(launch_result)); - return Status::kErrorInternal; - } -#else - CUTLASS_TRACE_HOST(" Programmatic dependent launch (PDL) is only supported starting CUDA 11.8."); - return Status::kInvalid; -#endif - } - - cudaError_t result = cudaGetLastError(); - if (cudaSuccess == result) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::kernel_launch: cudaGetLastError reports success"); -#endif - return Status::kSuccess; - } - else { - CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); - return Status::kErrorInternal; - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/layout.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/layout.h deleted file mode 100644 index b2e377c21339ff6c71d45370fa0572bf15c3f415..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/layout.h +++ /dev/null @@ -1,64 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used by TensorRef and derived classes. - - Layout functions map logical coordinates to linear memory. They often require additional - data to describe strides between elements. - - Layout functions must implement all members in the public interface of IdentityTensorLayout<> - defined in cutlass/tensor_ref.h. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/vector.h" - -#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace layout { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/matrix.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/matrix.h deleted file mode 100644 index 281b668ba59e3ddd7a1861e995ba7def13b83df2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/matrix.h +++ /dev/null @@ -1,1349 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used by TensorRef and derived classes. - - Layout functions map logical coordinates to linear memory. They often require additional - data to describe strides between elements. - - Layout functions must implement all members in the public interface of IdentityTensorLayout<> - defined in cutlass/tensor_ref.h. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/pitch_linear_coord.h" - -namespace cutlass { -namespace layout { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Defines data layouts of various matrix formats usable by TensorRef and other classes. -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for row-major matrices. -class RowMajor { -public: - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - RowMajor(LongIndex ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajor(Stride stride): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajor packed(MatrixCoord const &extent) { - return RowMajor(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column(); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - return MatrixCoord(Index(offset / stride_[0]), Index(offset % stride_[0])); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return LongIndex(extent.row()) * LongIndex(stride_[0]); - } -}; - -/// Mapping function for column-major matrices. -class ColumnMajor { -public: - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajor(LongIndex ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajor(Stride stride): stride_(stride) { } - - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajor packed(MatrixCoord const &extent) { - return ColumnMajor(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row(); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - return MatrixCoord(Index(offset % stride_[0]), Index(offset / stride_[0])); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return LongIndex(extent.column()) * LongIndex(stride_[0]); - } -}; - -/// Mapping function for interleaved matrices. Matrix is structured -/// as row-major arrangement of fixed-size columns. -template -struct RowMajorInterleaved { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - /// Size of interleaved columns - static int const kInterleave = Interleave; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorInterleaved(Stride stride): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorInterleaved packed(MatrixCoord const &extent) { - return RowMajorInterleaved(extent.column() * kInterleave); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - Index row_major = coord.row() / kInterleave; - Index row_minor = coord.row() % kInterleave; - return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor; - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - - Index row_major = Index(offset / stride_[0]); - Index residual = Index(offset % stride_[0]); - - Index column = residual / kInterleave; - Index row_minor = residual % kInterleave; - - return MatrixCoord(row_major * kInterleave + row_minor, column); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return (extent.row() + kInterleave - 1) / kInterleave * stride_[0]; - } -}; - -/// Mapping function for interleaved matrices. Matrix is structured -/// as column-major arrangement of fixed-size rows. -template -struct ColumnMajorInterleaved { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - /// Size of interleaved columns - static int const kInterleave = Interleave; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorInterleaved(Stride stride): stride_(stride) { } - - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorInterleaved packed(MatrixCoord const &extent) { - return ColumnMajorInterleaved(extent.row() * kInterleave); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - Index column_major = coord.column() / kInterleave; - Index column_minor = coord.column() % kInterleave; - return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor; - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - - Index column_major = Index(offset / stride_[0]); - Index residual = Index(offset % stride_[0]); - - Index row = residual / kInterleave; - Index column_minor = residual % kInterleave; - - return MatrixCoord(row, column_major * kInterleave + column_minor); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return (extent.column() + kInterleave - 1) / kInterleave * stride_[0]; - } -}; - -/// Enumerated type for canonical pitch-linear matrix layouts -enum class Matrix { - kColumnMajor, ///< leading dimension refers to stride between columns; stride along rows is 1 - kRowMajor ///< leading dimension refers to stride between rows; stride along columns is 1 -}; - -/// Mapping function for scenario in which layout is row-major or column-major but this information -/// is only available at runtime. -struct ContiguousMatrix { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - - /// Enumerated type indicating canonical matrix layout - Matrix layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ContiguousMatrix( - Index ldm = 0, - Matrix layout = Matrix::kColumnMajor - ): - stride_(ldm), layout_(layout) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ContiguousMatrix packed( - MatrixCoord const &extent, - Matrix layout = Matrix::kColumnMajor) { - - Index ldm = 0; - if (layout == Matrix::kColumnMajor) { - ldm = extent.row(); - } - else if (layout == Matrix::kRowMajor) { - ldm = extent.column(); - } - return ContiguousMatrix(ldm, layout); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - if (layout_ == Matrix::kColumnMajor) { - return coord.row() + coord.column() * stride_[0]; - } - else if (layout_ == Matrix::kRowMajor) { - return coord.row() * stride_[0] + coord.column(); - } - else { - // degenerate case - return 0; - } - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - CUTLASS_UNUSED(offset); - return MatrixCoord(0, 0); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - if (layout_ == Matrix::kColumnMajor) { - return stride_[0] * extent.column(); - } - else if (layout_ == Matrix::kRowMajor) { - return stride_[0] * extent.row(); - } - else { - // degenerate case - return 0; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for scenario in which both rows and columns are separated by a stride. -template -struct AffineRankN { - - /// Logical rank of tensor - static int const kRank = Rank; - - /// Rank of stride vector - static int const kStrideRank = kRank; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = Coord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRankN( - Stride const &stride = Stride() - ): - stride_(stride) { } - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRankN( - Coord const &stride_m, - Coord const &stride_n - ) { - - // Concatenate the strides - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kRank/2; ++m) { - stride_[m] = stride_m[m]; - } - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kRank/2; ++n) { - stride_[n + kRank/2] = stride_n[n]; - } - } - - /// Ctor for N = 2 - CUTLASS_HOST_DEVICE - AffineRankN( - LongIndex const &stride_m, - LongIndex const &stride_n - ) { - stride_[0] = stride_m; - stride_[1] = stride_n; - } - - /// Ctor for N = 2 - CUTLASS_HOST_DEVICE - AffineRankN( - LongIndex const &stride - ) { - stride_[0] = stride; - stride_[1] = 1; - } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static AffineRankN packed(TensorCoord const &extent) { - - AffineRankN layout; - layout.stride_[kRank - 1] = 1; - - CUTLASS_PRAGMA_UNROLL - for (int i = kRank - 1; i > 0; --i) { - layout.stride_[i - 1] = layout.stride_[i] * extent[i]; - } - - return layout; - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return dot(coord, stride_); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - return TensorCoord(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - int idx = stride_.max_dim_index(); - return extent[idx] * stride_[idx]; - } -}; - -/// Mapping function for scenario in which both rows and columns are separated by a stride. -/// Row stride is smaller than column stride in AffineRank2ColumnMajor. -struct AffineRank2ColumnMajor { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 2; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRank2ColumnMajor( - Stride const &stride = Stride() - ): - stride_(stride) { } - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRank2ColumnMajor( - LongIndex row_stride, ///< stride between elements in consecutive rows - LongIndex column_stride ///< stride between elements in consecutive columns - ) - { stride_[0] = row_stride; stride_[1] = column_stride;} - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRank2ColumnMajor( - LongIndex stride - ) - { stride_[0] = 1; stride_[1] = stride;} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static AffineRank2ColumnMajor packed(MatrixCoord const &extent) { - return AffineRank2ColumnMajor(1, extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - return dot(coord, stride_); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - CUTLASS_UNUSED(offset); - return MatrixCoord(0, 0); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return extent.column() * stride_[1]; - } -}; - -/// Mapping function for scenario in which both rows and columns are separated by a stride. -/// Column stride is smaller than row stride in AffineRank2RowMajor. -struct AffineRank2RowMajor { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 2; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRank2RowMajor( - Stride const &stride = Stride() - ): - stride_(stride) { } - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRank2RowMajor( - LongIndex row_stride, ///< stride between elements in consecutive rows - LongIndex column_stride ///< stride between elements in consecutive columns - ) { stride_[0] = row_stride; stride_[1] = column_stride;} - - /// Ctor - CUTLASS_HOST_DEVICE - AffineRank2RowMajor( - LongIndex stride - ) { stride_[0] = stride; stride_[1] = 1;} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static AffineRank2RowMajor packed(MatrixCoord const &extent) { - return AffineRank2RowMajor(1, extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - return dot(coord, stride_); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - CUTLASS_UNUSED(offset); - return MatrixCoord(0, 0); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return extent.row() * stride_[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Utility functions to convert stride_factor to the strides used by the Affine2 layout. -// -// stride_factor is the logical distance between two coorinates. -// -// All Coodinates used here are matrix coordinates. stride[0] and extent[0] are for the -// rows. stride[1] and extent[1] are for the columns. -template - struct Affine2Layout_Factory { - CUTLASS_HOST_DEVICE - static Affine2Layout layout_factory(cutlass::Coord<2> const &extent, typename Affine2Layout::Stride stride_factor) { - return Affine2Layout::packed(extent); - } -}; - -template <> -struct Affine2Layout_Factory { -CUTLASS_HOST_DEVICE -static cutlass::layout::AffineRank2ColumnMajor layout_factory( - cutlass::Coord<2> const &extent, - typename cutlass::layout::AffineRank2ColumnMajor::Stride stride_factor) { - return cutlass::layout::AffineRank2ColumnMajor({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); - } -}; - -template <> -struct Affine2Layout_Factory { -CUTLASS_HOST_DEVICE -static cutlass::layout::AffineRank2RowMajor layout_factory( - cutlass::Coord<2> const &extent, - typename cutlass::layout::AffineRank2RowMajor::Stride stride_factor) { - return cutlass::layout::AffineRank2RowMajor({ stride_factor[0] * stride_factor[1] * extent[1], stride_factor[1] }); - } -}; - -// The base layout cutlass::layout::AffineRankN<2> is similar to AffineRank2ColumnMajor -template <> -struct Affine2Layout_Factory> { -CUTLASS_HOST_DEVICE -static cutlass::layout::AffineRankN<2> layout_factory( - cutlass::Coord<2> const &extent, - typename cutlass::layout::AffineRankN<2>::Stride stride_factor) { - return cutlass::layout::AffineRankN<2>({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for block-linear matrices. Matrix is structured -/// as column-major arrangement of 2D tiles (that are column-major). -template -struct ColumnMajorBlockLinear { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - /// Size of a block in rows - static int const kBlockRows = BlockRows; - - /// Size of a block in columns - static int const kBlockColumns = BlockColumns; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorBlockLinear(Index ldm = 0): stride_(ldm) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorBlockLinear packed(MatrixCoord const &extent) { - return ColumnMajorBlockLinear(extent.row() * kBlockRows * kBlockColumns); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - return - (coord.row() % kBlockRows) + - (coord.column() % kBlockColumns) * kBlockRows + - (coord.row() / kBlockRows) * kBlockRows * kBlockColumns + - (coord.column() / kBlockColumns) * stride_[0]; - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - - return MatrixCoord(0, 0); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return (extent.column() + kBlockColumns - 1) / kBlockColumns * stride_[0]; - } -}; - -/// Mapping function for block-linear matrices. Matrix is structured -/// as row-major arrangement of 2D tiles (that are row-major) -template -struct RowMajorBlockLinear { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - /// Size of a block in rows - static int const kBlockRows = BlockRows; - - /// Size of a block in columns - static int const kBlockColumns = BlockColumns; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorBlockLinear(Index ldm = 0): stride_(ldm) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorBlockLinear packed(MatrixCoord const &extent) { - return RowMajorBlockLinear(extent.column() * kBlockRows * kBlockColumns); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - return - (coord.column() % kBlockColumns) + - (coord.row() % kBlockRows) * kBlockColumns + - (coord.column() / kBlockColumns) * kBlockRows * kBlockColumns + - (coord.row() / kBlockRows) * stride_[0]; - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - MatrixCoord inverse(LongIndex offset) const { - return MatrixCoord(0, 0); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - return (extent.row() + kBlockRows - 1) / kBlockRows * stride_[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct GeneralMatrix { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 2; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - Matrix layout_id_; - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - GeneralMatrix(): layout_id_(Matrix::kColumnMajor), stride_(make_Coord(0, 1)) { } - - /// Ctor - CUTLASS_HOST_DEVICE - GeneralMatrix( - Matrix layout_id, - Index ldm, - Index interleave): layout_id_(layout_id), stride_(make_Coord(ldm, interleave)) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static GeneralMatrix packed( - MatrixCoord const &extent, - Matrix layout_id = Matrix::kColumnMajor, - Index interleave = 1) { - - Index c; - if (layout_id == Matrix::kRowMajor) { - c = extent.column(); - } - else { - c = extent.row(); - } - - Index ldm = c * interleave; - - return GeneralMatrix(layout_id, ldm, interleave); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord const &coord) const { - Index c, s; - if (layout_id_ == Matrix::kRowMajor) { - c = coord.column(); - s = coord.row(); - } - else { - s = coord.column(); - c = coord.row(); - } - - Index v = s / stride_[1]; - Index residual = (s % stride_[1]); - - return LongIndex(c) * LongIndex(stride_[1]) + LongIndex(v) * LongIndex(stride_[0]) + residual; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - CUTLASS_HOST_DEVICE - Matrix layout_id() const { - return layout_id_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - CUTLASS_HOST_DEVICE - Matrix & layout_id() { - return layout_id_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index stride(int idx) const { - return stride_[idx]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - typename Stride::Index & stride(int idx) { - return stride_[idx]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(MatrixCoord const &extent) const { - Index s; - if (layout_id_ == Matrix::kRowMajor) { - s = extent.row(); - } - else { - s = extent.column(); - } - - Index v = Index((s + stride_[1] - 1) / stride_[1]); - return LongIndex(v) * LongIndex(stride_[0]); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines transposes of matrix layouts -template -struct LayoutTranspose; - -/// Transpose of row-major is column-major -template <> -struct LayoutTranspose { - using type = layout::ColumnMajor; -}; - -/// Transpose of column-major is row-major -template <> -struct LayoutTranspose { - using type = layout::RowMajor; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/permute.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/permute.h deleted file mode 100644 index 99e3353f7ba0be2fef2a4a9c475e3babe0b70058..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/permute.h +++ /dev/null @@ -1,824 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used by GEMM+permute path for common tensor or matrix formats. - - Like Layout functions, permute layout functions map logical coordinates to linear memory. They often require additional - data to describe strides between elements. - - Permute layout functions must implement all members in the interface of NoPermute<> defined in this file. Address offset - computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_} as new addresses after permute op. -*/ -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) -#include "cutlass/fast_math.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/coord.h" -#include "cutlass/tensor_coord.h" - -namespace cutlass { -namespace layout { - -// template -// struct PermuteSelect { -// // Try to give a reasonable error message to the user -// static_assert(!platform::is_same::value, // aka always_false -// "You've tried to use a layout permutation for which the implementation is not availble. " -// "In order to provide an implementation for a particular combination of matrix layout " -// "and direction (direct/inverse), please specialize PermuteSelect trait."); -// }; - -// Base template for defining specializations of permutation inverses -template -struct InversePermute -{ - // Try to give a reasonable error message to the user - static_assert(!platform::is_same::value, // aka always_false - "To apply permutation to a GEMM input operand (A or B), an inverse permutation for the desired " - "permute class must be defined and enabled by specializing cutlass::layout::InversePermute trait."); -}; - -class PermuteBase { -public: - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; -}; - -class NoPermute : public PermuteBase { -public: - // - // Methods - // - - /// Constructor from matrix extent - CUTLASS_HOST_DEVICE - NoPermute(MatrixCoord extent, Index stride) { }; - - /// Constructor from pitch-linear extent - CUTLASS_HOST_DEVICE - NoPermute(PitchLinearCoord extent, Index stride) { }; - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { return 0; } // not correct but should never be called - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { return 0; } // not correct but should never be called -}; - -template<> -struct InversePermute { - using type = NoPermute; -}; - -/// Helper trait to detect if permute operation is a noop -template -inline bool constexpr is_trivial_permute = platform::is_same::value; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Defines permute layouts of various tensor formats. -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Tensor4DPermute0213 -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Permute layout function for 4-D permuted tensors with matrix (dimensions [M, N]) reshaped -/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. -template -class Tensor4DPermute0213RowMajor : public PermuteBase { -private: - // - // Data members - // - - Index D3_; - - Index stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermute0213RowMajor(MatrixCoord extent, Index stride) { - - assert(extent.row() % D1 == 0); - assert(extent.column() % D2 == 0); - - D3_ = extent.column() / D2; - - stride_ = stride * D1 / D2; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermute0213RowMajor(PitchLinearCoord extent, Index stride) - : Tensor4DPermute0213RowMajor(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - // [i,j,k,l] -> [i,k,j,l] - Index l = coord.column() % D3_; - Index k = coord.column() / D3_; - Index j = coord.row() % D1; - Index i = coord.row() / D1; - - MatrixCoord permuted{k + i * D2, l + j * D3_}; - - return LongIndex(permuted.row()) * LongIndex(stride_) + LongIndex(permuted.column()); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.strided(), coord.contiguous())); - } -}; - -// Inverse for Tensor4DPermute0213 can be implemented by simply swapping D1 and D2 -template -class Tensor4DPermute0213RowMajorInverse : public Tensor4DPermute0213RowMajor { -public: - using Base = Tensor4DPermute0213RowMajor; - using Base::Base; -}; - -template -struct InversePermute> { - using type = Tensor4DPermute0213RowMajorInverse; -}; - -template -struct InversePermute> { - using type = Tensor4DPermute0213RowMajor; -}; - -/// Permute layout function for 4-D permuted tensors with matrix (dimensions [M, N]) reshaped -/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. -template -class Tensor4DPermute0213ColumnMajor : public PermuteBase { -private: - // - // Data members - // - - Index D0_; - - Index stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermute0213ColumnMajor(MatrixCoord extent, Index stride) { - - assert(extent.row() % D1 == 0); - assert(extent.column() % D2 == 0); - - D0_ = extent.row() / D1; - - stride_ = stride * D2 / D1; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermute0213ColumnMajor(PitchLinearCoord extent, Index stride) - : Tensor4DPermute0213ColumnMajor(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - // [i,j,k,l] -> [i,k,j,l] - Index l = coord.column() / D2; - Index k = coord.column() % D2; - Index j = coord.row() / D0_; - Index i = coord.row() % D0_; - - MatrixCoord permuted{i + k * D0_, j + l * D1}; - - return LongIndex(permuted.row()) + LongIndex(permuted.column()) * LongIndex(stride_); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.contiguous(), coord.strided())); - } -}; - -// Inverse for Tensor4DPermute0213 can be implemented by simply swapping D1 and D2 -template -class Tensor4DPermute0213ColumnMajorInverse : public Tensor4DPermute0213ColumnMajor { -public: - using Base = Tensor4DPermute0213ColumnMajor; - using Base::Base; -}; - -template -struct InversePermute> { - using type = Tensor4DPermute0213ColumnMajorInverse; -}; - -template -struct InversePermute> { - using type = Tensor4DPermute0213ColumnMajor; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Tensor4DPermuteBMM0213 -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimensions [B, M, N]) reshaped -/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM tensor. -template -class Tensor4DPermuteBMM0213RowMajor : public PermuteBase { -private: - // - // Data members - // - - Index D3_; - - Index stride_; - - Index batch_stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0213RowMajor(MatrixCoord extent, Index stride) { - - Index D2 = extent.row(); - D3_ = extent.column(); - - stride_ = stride * D1; - batch_stride_ = D2 * stride_; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0213RowMajor(PitchLinearCoord extent, Index stride) - : Tensor4DPermuteBMM0213RowMajor(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - // The batch index for BMM - Index BMM_batch_idx = blockIdx.z; - - // [i,j,k,l] -> [i,k,j,l] - Index l = coord.column(); - Index k = coord.row(); - Index j = BMM_batch_idx % D1; - Index i = BMM_batch_idx / D1; - - Index pbatch = i; - MatrixCoord pcoord{k, l + j * D3_}; - - return pbatch * LongIndex(batch_stride_) + pcoord.row() * LongIndex(stride_) + pcoord.column(); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.strided(), coord.contiguous())); - } -}; - -template -class Tensor4DPermuteBMM0213RowMajorInverse : public PermuteBase { -private: - // - // Data members - // - - Index D3_; - - Index stride_; - - Index batch_stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0213RowMajorInverse(MatrixCoord extent, Index stride) { - - assert(extent.column() % D1 == 0); - - Index D2 = extent.row(); - D3_ = extent.column() / D1; - - stride_ = stride / D1; - - batch_stride_ = D2 * stride_; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0213RowMajorInverse(PitchLinearCoord extent, Index stride) - : Tensor4DPermuteBMM0213RowMajorInverse(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - // The batch index for BMM - Index BMM_batch_idx = blockIdx.z; - - // The following assumes grouping [(D0)->batch, (D2)->row, (D1,D3)->col] - Index l = coord.column() % D3_; - Index j = coord.column() / D3_; - Index k = coord.row(); - Index i = BMM_batch_idx; - - // compute original [batch, row, col] index - Index pbatch = j + i * D1; - MatrixCoord pcoord{k, l}; - - return pbatch * LongIndex(batch_stride_) + pcoord.row() * LongIndex(stride_) + pcoord.column(); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.strided(), coord.contiguous())); - } -}; - -template -struct InversePermute> { - using type = Tensor4DPermuteBMM0213RowMajorInverse; -}; - -template -struct InversePermute> { - using type = Tensor4DPermuteBMM0213RowMajor; -}; - -/// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimensions [B, M, N]) reshaped -/// as [B/D1, D1, M, N]. Then perform permute([0, 3, 2, 1]) on the corresponding whole BMM tensor. -template -class Tensor4DPermuteBMM0321ColumnMajor : public PermuteBase { -private: - // - // Data members - // - - Index D2_; - - Index stride_; - - Index batch_stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0321ColumnMajor(MatrixCoord extent, Index stride) { - - D2_ = extent.row(); - Index D3 = extent.column(); - - stride_ = stride * D1; - batch_stride_ = stride_ * D3; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0321ColumnMajor(PitchLinearCoord extent, Index stride) - : Tensor4DPermuteBMM0321ColumnMajor(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - Index BMM_batch_idx = blockIdx.z; - - // [i,j,k,l] -> [i,k,j,l] - Index l = coord.column(); - Index k = coord.row(); - Index j = BMM_batch_idx % D1; - Index i = BMM_batch_idx / D1; - - Index pbatch = i; - MatrixCoord pcoord{k + j * D2_, l}; - - return pbatch * LongIndex(batch_stride_) + pcoord.row() + pcoord.column() * LongIndex(stride_); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.contiguous(), coord.strided())); - } -}; - -template -class Tensor4DPermuteBMM0321ColumnMajorInverse : public PermuteBase { -private: - // - // Data members - // - - Index D2_; - - Index stride_; - - Index batch_stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0321ColumnMajorInverse(MatrixCoord extent, Index stride) { - - assert(extent.row() % D1 == 0); - - D2_ = extent.row() / D1; - Index D3 = extent.column(); - - stride_ = stride / D1; - batch_stride_ = stride_ * D3; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0321ColumnMajorInverse(PitchLinearCoord extent, Index stride) - : Tensor4DPermuteBMM0321ColumnMajorInverse(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - Index BMM_batch_idx = blockIdx.z; - - // The following assumes grouping [(D0)->batch, (D1,D2)->row, (D3)->col] - Index l = coord.column(); - Index k = coord.row() % D2_; - Index j = coord.row() / D2_; - Index i = BMM_batch_idx; - - Index pbatch = i * D1 + j; - MatrixCoord pcoord{k, l}; - - return pbatch * LongIndex(batch_stride_) + pcoord.row() + pcoord.column() * LongIndex(stride_); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.contiguous(), coord.strided())); - } -}; - -template -struct InversePermute> { - using type = Tensor4DPermuteBMM0321ColumnMajorInverse; -}; - -template -struct InversePermute> { - using type = Tensor4DPermuteBMM0321ColumnMajor; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Tensor5DPermute20314 -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped -/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. -template -class Tensor5DPermute20314RowMajor : public PermuteBase { -private: - // - // Data members - // - - Index T0_; - - Index T4_; - - Index stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute20314RowMajor(MatrixCoord extent, Index stride) { - - assert(extent.row() % T1 == 0); - assert(extent.column() % (T2 * T3) == 0); - - T0_ = extent.row() / T1; - T4_ = extent.column() / (T2 * T3); - - /// Update stride_permute with stride - stride_ = stride / T2 * T1; // stride in Elements - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute20314RowMajor(PitchLinearCoord extent, Index stride) - : Tensor5DPermute20314RowMajor(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} - - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X - // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T2, T0, T3, T1, T4]. - - Index m = coord.column() % T4_; - Index l = (coord.column() / T4_) % T3; - Index k = (coord.column() / T4_) / T3; - Index j = coord.row() % T1; - Index i = coord.row() / T1; - - MatrixCoord permuted{i + k * T0_, m + j * T4_ + l * T1 * T4_}; - - return LongIndex(permuted.row()) * LongIndex(stride_) + LongIndex(permuted.column()); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.strided(), coord.contiguous())); - } -}; - -/// Inverse for Tensor5DPermute20314 (could also be given a proper name, e.g. Tensor5DPermute13024). -template -class Tensor5DPermute20314RowMajorInverse : public PermuteBase { -private: - // - // Data members - // - - Index T0_; - - Index T4_; - - // Permuted stride in units of elements - Index stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute20314RowMajorInverse(MatrixCoord extent, Index stride) { - - assert(extent.row() % T2 == 0); - assert(extent.column() % (T1 * T3) == 0); - - T0_ = extent.row() / T2; - T4_ = extent.column() / (T1 * T3); - - stride_ = stride / T1 * T2; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute20314RowMajorInverse(PitchLinearCoord extent, Index stride) - : Tensor5DPermute20314RowMajorInverse(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} - - /// Computes the offset after the inverse of permute operation in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - Index m = coord.column() % T4_; - Index j = (coord.column() / T4_) % T1; - Index l = (coord.column() / T4_) / T1; - Index i = coord.row() % T0_; - Index k = coord.row() / T0_; - - MatrixCoord permuted{j + i * T1, m + l * T4_ + k * T3 * T4_}; - - return LongIndex(permuted.row()) * LongIndex(stride_) + LongIndex(permuted.column()); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.strided(), coord.contiguous())); - } -}; - -template -struct InversePermute> { - using type = Tensor5DPermute20314RowMajorInverse; -}; - -template -struct InversePermute> { - using type = Tensor5DPermute20314RowMajor; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Tensor5DPermute02413 -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Permute layout function for 5-D permuted tensors with matrix (dimensions [M, N]) reshaped -/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([0, 2, 4, 1, 3]) on the corresponding tensor. -template -class Tensor5DPermute02413ColumnMajor : public PermuteBase { -private: - // - // Data members - // - - Index T0_; - - Index T4_; - - Index stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute02413ColumnMajor(MatrixCoord extent, Index stride) { - - assert(extent.row() % T1 == 0); - assert(extent.column() % (T2 * T3) == 0); - - T0_ = extent.row() / T1; - T4_ = extent.column() / (T2 * T3); - - /// Update stride_permute with stride - stride_ = stride / T1 * T2; // stride in Elements - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute02413ColumnMajor(PitchLinearCoord extent, Index stride) - : Tensor5DPermute02413ColumnMajor(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X - // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T0, T2, T4, T1, T3]. - - Index m = (coord.column() / T2) / T3; - Index l = (coord.column() / T2) % T3; - Index k = coord.column() % T2; - Index j = coord.row() / T0_; - Index i = coord.row() % T0_; - - MatrixCoord permuted{i + k * T0_, m + j * T4_ + l * T4_ * T1}; - - return LongIndex(permuted.row()) + LongIndex(permuted.column()) * LongIndex(stride_); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.contiguous(), coord.strided())); - } -}; - -/// Inverse for Tensor5DPermute02413ColumnMajor -template -class Tensor5DPermute02413ColumnMajorInverse : public PermuteBase { -private: - // - // Data members - // - - Index T0_; - - Index T4_; - - // Permuted stride in units of elements - Index stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute02413ColumnMajorInverse(MatrixCoord extent, Index stride) { - - assert(extent.row() % T2 == 0); - assert(extent.column() % (T1 * T3) == 0); - - T0_ = extent.row() / T2; - T4_ = extent.column() / (T1 * T3); - - stride_ = stride / T2 * T1; - } - - /// Constructor - CUTLASS_HOST_DEVICE - Tensor5DPermute02413ColumnMajorInverse(PitchLinearCoord extent, Index stride) - : Tensor5DPermute02413ColumnMajorInverse(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} - - /// Computes the offset after the inverse of permute operation in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord coord) const { - - Index m = coord.column() % T4_; - Index j = (coord.column() / T4_) % T1; - Index l = (coord.column() / T4_) / T1; - Index i = coord.row() % T0_; - Index k = coord.row() / T0_; - - MatrixCoord permuted{i + j * T0_, k + l * T2 + m * T2 * T3}; - - return LongIndex(permuted.row()) + LongIndex(permuted.column()) * LongIndex(stride_); - } - - /// Computes the offset after Permute Op in logical elements - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return operator()(MatrixCoord(coord.contiguous(), coord.strided())); - } -}; - -template -struct InversePermute> { - using type = Tensor5DPermute02413ColumnMajorInverse; -}; - -template -struct InversePermute> { - using type = Tensor5DPermute02413ColumnMajor; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/pitch_linear.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/pitch_linear.h deleted file mode 100644 index 7052de14a2d2614c0d76d1423a3cda126cef6c68..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/pitch_linear.h +++ /dev/null @@ -1,149 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" -#include "cutlass/pitch_linear_coord.h" - -namespace cutlass { -namespace layout { - -template - using PitchLinearShape = cutlass::PitchLinearShape < Contiguous, Strided >; - using PitchLinearCoord = PitchLinearCoord; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for pitch-linear memory -class PitchLinear { -public: - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - PitchLinear(LongIndex ldm = 0): stride_(ldm) { } - - /// Constructor - CUTLASS_HOST_DEVICE - PitchLinear(Stride _stride): stride_(_stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static PitchLinear packed(TensorCoord const &extent) { - return PitchLinear(extent.contiguous()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]); - } - - /// Returns the logical coordinate given an offset. - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex index) const { - return make_Coord( - TensorCoord::Index(index % stride_[0]), - TensorCoord::Index(index / stride_[0]) - ); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - LongIndex stride(int rank) const { - return stride_[rank]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - LongIndex & stride(int rank) { - return stride_[rank]; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent.strided() * stride_[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor.h deleted file mode 100644 index 9e8a354e663e486f58925403829ba10cbd775f76..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor.h +++ /dev/null @@ -1,644 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D - tensor formats. - - Layout functions map logical coordinates to linear memory. They often require additional - data to describe strides between elements. - - Layout functions must implement all members in the public interface of IdentityTensorLayout<> - defined in cutlass/tensor_ref.h. -*/ -#pragma once -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(cassert) -#include "cutlass/fast_math.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/coord.h" -#include "cutlass/tensor_coord.h" - -namespace cutlass { -namespace layout { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Defines data layouts of various tensor formats usable by TensorRef and other classes. -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tag used for 3-D NWC tensors for 1-D convolutions; only used in 3.x API -class TensorNWC {}; - -/// Tag used for n-D KCSRT tensors for n-D convolutions; only used in 3.x API for wgrad output layouts -class TensorKCS {}; -class TensorKCSR {}; -class TensorKCSRT {}; - -/// Tag used for n-D CSRTK tensors for n-D convolutions; only used in 3.x API for wgrad output layouts -class TensorCSK {}; -class TensorCSRK {}; -class TensorCSRTK {}; - -/// Mapping function for 4-D NHWC tensors. -class TensorNHWC { -public: - /// Logical rank of tensor - static int const kRank = 4; - - /// Rank of stride vector - static int const kStrideRank = 3; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate (n, h, w, c) - using TensorCoord = Tensor4DCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - [stride_w, stride_h, stride_n] - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNHWC(Stride const &stride = Stride(0)): stride_(stride) { } - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNHWC( - typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates - typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates - typename Stride::Index stride_n ///< number of elements between adjacent N coordinates - ): - stride_(make_Coord(stride_w, stride_h, stride_n)) { } - - /// Constructor - // Once convolutions implement 64b stride this ctor can be deleted - CUTLASS_HOST_DEVICE - TensorNHWC(Coord const &stride): - stride_(make_Coord( - static_cast(stride[0]), - static_cast(stride[1]), - static_cast(stride[2])) - ) { } - - /// Helper returns a layout to a tightly packed NHWC tensor. - CUTLASS_HOST_DEVICE - static TensorNHWC packed(TensorCoord const &extent) { - return TensorNHWC( - make_Coord( - extent.c(), - extent.w() * extent.c(), - extent.h() * extent.w() * extent.c() - ) - ); - } - - /// Returns the offset of a coordinate (n, h, w, c) in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return coord.c() + - LongIndex(stride_[0] * coord.w()) + - LongIndex(stride_[1] * coord.h()) + - LongIndex(stride_[2] * coord.n()); - } - - /// Returns the offset of a pitchlinear coordinate in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); - } - - /// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory. - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex index) const { - - int n = 0, h = 0, w = 0, c = 0; - - #if defined(__CUDA_ARCH__) - int tmp = 0; - c = int(index % static_cast(stride_[0])); - - unsigned int hw_mul, hw_shr, w_mul, w_shr, c_mul, c_shr; - - find_divisor(hw_mul, hw_shr, stride_[2]); - find_divisor(w_mul, w_shr, stride_[1]); - find_divisor(c_mul, c_shr, stride_[0]); - - fast_divmod(n, tmp, index, int(stride_[2]), hw_mul, hw_shr); - fast_divmod(h, w, tmp, int(stride_[1]), w_mul, w_shr); - fast_divmod(w, tmp, w, int(stride_[0]), c_mul, c_shr); - #else - - n = int(index / stride_[2]); - LongIndex residual = index % stride_[2]; - - h = int(residual / stride_[1]); - residual = (residual % stride_[1]); - - w = int(residual / stride_[0]); - c = int(residual % stride_[0]); - - #endif - return TensorCoord(n, h, w, c); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - // it does not make sense if the extent is larger than stride - // and we could not rely on the capacity calculation in such cases - // we could move this checkers to debug code only - if ((extent.c() > stride_[0]) - || (extent.w() * stride_[0] > stride_[1]) - || (extent.h() * stride_[1] > stride_[2])) { - assert(0); - } - return extent.n() * stride_[2]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for 4-D NCHW tensors. -class TensorNCHW { -public: - /// Logical rank of tensor - static int const kRank = 4; - - /// Rank of stride vector - static int const kStrideRank = 3; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = Tensor4DCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - [w, hw, chw] - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNCHW(Stride const &stride = Stride(0)): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorNCHW packed(TensorCoord const &extent) { - return TensorNCHW( - make_Coord( - extent.w(), - extent.w() * extent.h(), - extent.h() * extent.w() * extent.c() - ) - ); - } - - /// Returns the offset of a coordinate in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return coord.w() + - LongIndex(stride_[0] * coord.h()) + - LongIndex(stride_[1] * coord.c()) + - LongIndex(stride_[2] * coord.n()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent.n() * stride_[2]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for 4-D NC/xHWx tensors. -template -class TensorNCxHWx { -public: - - /// Interleaving quantity - static int const kInterleave = Interleave; - - /// Logical rank of tensor - static int const kRank = 4; - - /// Rank of stride vector - static int const kStrideRank = 3; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = Tensor4DCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - [Interleave x w, Interleave x wh, hwc] - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNCxHWx(Stride const &stride = Stride(0)): stride_(stride) { } - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNCxHWx( - typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates - typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates - typename Stride::Index stride_n ///< number of elements between adjacent N coordinates - ): - stride_(make_Coord(stride_w, stride_h, stride_n)) { } - - /// Constructor - // Once convolutions implement 64b stride this ctor can be deleted - CUTLASS_HOST_DEVICE - TensorNCxHWx(Coord const &stride): - stride_(make_Coord( - static_cast(stride[0]), - static_cast(stride[1]), - static_cast(stride[2])) - ) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorNCxHWx packed(TensorCoord const &extent) { - return TensorNCxHWx( - make_Coord( - kInterleave * extent.w(), - kInterleave * extent.w() * extent.h(), - extent.h() * extent.w() * extent.c() - ) - ); - } - - /// Returns the offset of a coordinate in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - Index c_minor = (coord.c() % kInterleave); - Index c_major = (coord.c() / kInterleave); - - return c_minor + - LongIndex(kInterleave * coord.w()) + - LongIndex(stride_[0] * coord.h()) + - LongIndex(stride_[1] * c_major) + - LongIndex(stride_[2] * coord.n()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent.n() * stride_[2]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for 4-D CxRSKx tensors. -template -class TensorCxRSKx { -public: - - /// Interleaving quantity - static int const kInterleave = Interleave; - - /// Logical rank of tensor - static int const kRank = 4; - - /// Rank of stride vector - static int const kStrideRank = 3; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = Tensor4DCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - [Interleave x n, Interleave x nw, Interleave x nwh] - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - TensorCxRSKx(Stride const &stride = Stride(0)): stride_(stride) { } - - /// Constructor - CUTLASS_HOST_DEVICE - TensorCxRSKx( - typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates - typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates - typename Stride::Index stride_n ///< number of elements between adjacent N coordinates - ): - stride_(make_Coord(stride_w, stride_h, stride_n)) { } - - /// Constructor - // Once convolutions implement 64b stride this ctor can be deleted - CUTLASS_HOST_DEVICE - TensorCxRSKx(Coord const &stride): - stride_(make_Coord( - static_cast(stride[0]), - static_cast(stride[1]), - static_cast(stride[2])) - ) { } - - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorCxRSKx packed(TensorCoord const &extent) { - return TensorCxRSKx( - make_Coord( - kInterleave * extent.n(), - kInterleave * extent.n() * extent.w(), - kInterleave * extent.n() * extent.w() * extent.h() - ) - ); - } - - /// Returns the offset of a coordinate in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - Index c_minor = (coord.c() % kInterleave); - Index c_major = (coord.c() / kInterleave); - - return c_minor + - LongIndex(kInterleave * coord.n()) + - LongIndex(stride_[0] * coord.w()) + - LongIndex(stride_[1] * coord.h()) + - LongIndex(stride_[2] * c_major); - } - - /// Returns the offset of a pitchlinear coordinate in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord const &coord) const { - return (coord.contiguous() % kInterleave) + - LongIndex((coord.contiguous() / kInterleave) * stride_[2]) + - LongIndex(coord.strided() * kInterleave); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return (extent.c() / kInterleave * stride_[2]); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mapping function for 5-D NDHWC tensors. -class TensorNDHWC { -public: - /// Logical rank of tensor - static int const kRank = 5; - - /// Rank of stride vector - static int const kStrideRank = 4; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate (n, d, h, w, c) - using TensorCoord = Tensor5DCoord; - - /// Stride vector - using Stride = Coord; - -private: - // - // Data members - // - - /// Stride data member - [c, wc, hwc, dhwc] - Stride stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNDHWC(Stride const &stride = Stride(0)): stride_(stride) { } - - /// Constructor - CUTLASS_HOST_DEVICE - TensorNDHWC( - typename Stride::Index c, - typename Stride::Index wc, - typename Stride::Index hwc, - typename Stride::Index dhwc): - stride_(make_Coord(c, wc, hwc, dhwc)) { } - - /// Constructor - // Once convolutions implement 64b stride this ctor can be deleted - CUTLASS_HOST_DEVICE - TensorNDHWC(Coord const &stride): - stride_(make_Coord( - static_cast(stride[0]), - static_cast(stride[1]), - static_cast(stride[2]), - static_cast(stride[3])) - ) { } - - /// Helper returns a layout to a tightly packed NHWC tensor. - CUTLASS_HOST_DEVICE - static TensorNDHWC packed(TensorCoord const &extent) { - return TensorNDHWC( - make_Coord( - extent.c(), - extent.w() * extent.c(), - extent.h() * extent.w() * extent.c(), - extent.d() * extent.h() * extent.w() * extent.c() - ) - ); - } - - /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return coord.c() + - LongIndex(stride_[0] * coord.w()) + - LongIndex(stride_[1] * coord.h()) + - LongIndex(stride_[2] * coord.d()) + - LongIndex(stride_[3] * coord.n()); - } - - /// Returns the offset of a pitchlinear coordinate in linear memory. - CUTLASS_HOST_DEVICE - LongIndex operator()(PitchLinearCoord coord) const { - return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - // it does not make sense if the extent is larger than stride - // and we could not rely on the capacity calculation in such cases - // we could move this checkers to debug code only - if ((extent.c() > stride_[0]) - || (extent.w() * stride_[0] > stride_[1]) - || (extent.h() * stride_[1] > stride_[2]) - || (extent.d() * stride_[2] > stride_[3])) { - assert(0); - } - return extent.n() * stride_[3]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h deleted file mode 100644 index e4d25a5109c70d15e562881d79a2c384192b0346..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h +++ /dev/null @@ -1,1045 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_coord.h" // cutlass::MatrixCoord - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace layout { - -// template < -// int ElementSize, -// gemm::Operand Operand -// > -// struct VoltaTensorOpMultiplicandCongruous; - -// template < -// int ElementSize, -// gemm::Operand Operand -// > -// struct ColumnMajorVoltaTensorOpMultiplicandCongruous; -// template < -// int ElementSize, -// gemm::Operand Operand -// > -// struct RowMajorVoltaTensorOpMultiplicandCongruous; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -template -struct VoltaTensorOpMultiplicandCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - /// This layout is optimized for 128b accesses - static int const kAccessSize = 128; - - /// Fundamental tile shape in units of vectors - using TileShape = PitchLinearShape<8, 4>; - - /// Fundamental partition shape in units of vectors - using PartitionShape = PitchLinearShape<8, 2>; - - // - // Static constants - // - - static int const kElementSize = ElementSize; - static int const kElementsPerAccess = kAccessSize / kElementSize; - - using PartitionCount = PitchLinearShape< - TileShape::kContiguous / PartitionShape::kContiguous, - TileShape::kStrided / PartitionShape::kStrided - >; - - using AccessCount = PitchLinearShape< - PartitionShape::kContiguous, - PartitionShape::kStrided - >; - -private: - - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - VoltaTensorOpMultiplicandCongruous(Index ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - VoltaTensorOpMultiplicandCongruous(Stride stride): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static VoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return VoltaTensorOpMultiplicandCongruous(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - // First, compute c and s of vector within source (in units of vector accesses) - int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; - int vec_strided_idx = coord.strided(); - - // Compute the fundamental tile being accessed - int tile_contiguous_idx = vec_contiguous_idx / TileShape::kContiguous; - int tile_strided_idx = vec_strided_idx / TileShape::kStrided; - - int tile_contiguous_residual = vec_contiguous_idx % TileShape::kContiguous; - int tile_strided_residual = vec_strided_idx % TileShape::kStrided; - - // Then swizzle in a tile - // Swizzle pattern is (tid[2:0] << 2)|(tid[4:3] ^ tid[2:1]) - int permuted_strided_within_tile = (tile_contiguous_residual >> 1); - int permuted_contiguous_within_tile = (tile_strided_residual ^ permuted_strided_within_tile) | - ((tile_contiguous_residual & 1) << 2); - // Compute final element location - int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + - permuted_contiguous_within_tile) * kElementsPerAccess + (coord.contiguous() % kElementsPerAccess); - - int element_strided = tile_strided_idx * TileShape::kStrided + permuted_strided_within_tile; - - return element_contiguous + element_strided * stride_[0]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -template -struct ColumnMajorVoltaTensorOpMultiplicandCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = VoltaTensorOpMultiplicandCongruous; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorVoltaTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorVoltaTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorVoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return ColumnMajorVoltaTensorOpMultiplicandCongruous(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -/// Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -template -struct RowMajorVoltaTensorOpMultiplicandCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = VoltaTensorOpMultiplicandCongruous; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorVoltaTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorVoltaTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorVoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return RowMajorVoltaTensorOpMultiplicandCongruous(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - - -/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -// template -template -struct VoltaTensorOpMultiplicandBCongruous { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - /// This layout is optimized for 128b accesses - static int const kAccessSize = 128; - - /// Fundamental tile shape in units of vectors - using TileShape = PitchLinearShape<8, 4>; - - /// Fundamental partition shape in units of vectors - using PartitionShape = PitchLinearShape<4, 4>; - - // - // Static constants - // - - static int const kElementSize = ElementSize; - static int const kElementsPerAccess = kAccessSize / kElementSize; - - using PartitionCount = PitchLinearShape< - TileShape::kContiguous / PartitionShape::kContiguous, - TileShape::kStrided / PartitionShape::kStrided - >; - - using AccessCount = PitchLinearShape< - PartitionShape::kContiguous, - PartitionShape::kStrided - >; - -private: - - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - VoltaTensorOpMultiplicandBCongruous(Index ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - VoltaTensorOpMultiplicandBCongruous(Stride stride): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static VoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { - return VoltaTensorOpMultiplicandBCongruous(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - // First, compute c and s of vector within source (in units of vector accesses) - int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; - int vec_strided_idx = coord.strided(); - - // Compute the fundamental tile being accessed - int tile_contiguous_idx = vec_contiguous_idx / TileShape::kContiguous; - int tile_strided_idx = vec_strided_idx / TileShape::kStrided; - - int tile_contiguous_residual = vec_contiguous_idx % TileShape::kContiguous; - int tile_strided_residual = vec_strided_idx % TileShape::kStrided; - - // Then swizzle in a tile - // Swizzle pattern is (tid[1:0] << 3)|(tid & 0x4)|(tid[1:0]) - int permuted_strided_within_tile = (tile_contiguous_residual & 0x3); - int permuted_contiguous_within_tile = (tile_strided_residual ^ permuted_strided_within_tile) | - (tile_contiguous_residual & 0x4); - - // Compute final element location - int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + - permuted_contiguous_within_tile) * kElementsPerAccess + (coord.contiguous() % kElementsPerAccess); - - int element_strided = tile_strided_idx * TileShape::kStrided + permuted_strided_within_tile; - - return element_contiguous + element_strided * stride_[0]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -template -struct ColumnMajorVoltaTensorOpMultiplicandBCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = VoltaTensorOpMultiplicandBCongruous; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorVoltaTensorOpMultiplicandBCongruous(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorVoltaTensorOpMultiplicandBCongruous(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorVoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { - return ColumnMajorVoltaTensorOpMultiplicandBCongruous(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -/// Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -template -struct RowMajorVoltaTensorOpMultiplicandBCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = VoltaTensorOpMultiplicandBCongruous; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorVoltaTensorOpMultiplicandBCongruous(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorVoltaTensorOpMultiplicandBCongruous(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorVoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { - return RowMajorVoltaTensorOpMultiplicandBCongruous(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and KBlock size (in elements). -template -struct VoltaTensorOpMultiplicandCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - /// This layout is optimized for 64b accesses - static int const kAccessSize = 64; - - // - // Static constants - // - - static int const kElementSize = ElementSize; - static int const kElementsPerAccess = kAccessSize / kElementSize; - static int const kKBlock = KBlock; - - private: - // - // Data members - // - - /// Stride data member. For GEMM, it equals to KBlock x stage. - Stride stride_; - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - VoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - VoltaTensorOpMultiplicandCrosswise(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static VoltaTensorOpMultiplicandCrosswise packed(TensorCoord const &extent) { - return VoltaTensorOpMultiplicandCrosswise(extent[1]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - // - // First, compute c and s of vector within source (in units of vector - // accesses) - // - int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; - int vec_strided_idx = coord.strided(); - - // - // Then swizzle - // The mapping is like this: - // id[1:0]|(id[3]^id[4])|id[2] - - int vec_strided_within_tile = vec_contiguous_idx & 0x7; - int permuted_vec_contiguous = - (vec_strided_idx & (~0xF)) + (vec_strided_idx & 0x3) * 4 + - (((vec_strided_idx >> 2) ^ ((vec_strided_idx & 0x10) >> 3)) & 0x3); - - permuted_vec_contiguous ^= ((vec_strided_within_tile >> 1) & 0x3); - - int permuted_vec_strided = vec_contiguous_idx; - - // - // Compute final element location - // - - int element_contiguous = permuted_vec_contiguous * kElementsPerAccess + - (coord.contiguous() % kElementsPerAccess); - - return element_contiguous + permuted_vec_strided * (stride_[0] * kElementsPerAccess); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[0] * stride_[0]; - } -}; - -/// Template mapping a column-major view of pitch-linear memory to -/// VoltaTensorOpMultiplicandCrosswise -template -struct ColumnMajorVoltaTensorOpMultiplicandCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = VoltaTensorOpMultiplicandCrosswise; - - /// This layout is optimized for 64b accesses - static int const kAccessSize = Base::kAccessSize; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - - private: - // - // Data members - // - - Base layout_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorVoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorVoltaTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorVoltaTensorOpMultiplicandCrosswise packed( - TensorCoord const &extent) { - return ColumnMajorVoltaTensorOpMultiplicandCrosswise(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return layout_.stride(); } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return layout_.stride(); } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -/// Template mapping a row-major view of pitch-linear memory to -/// TensorOpMultiplicandCrosswise -template -struct RowMajorVoltaTensorOpMultiplicandCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = VoltaTensorOpMultiplicandCrosswise; - - /// This layout is optimized for 64b accesses - static int const kAccessSize = Base::kAccessSize; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - - private: - // - // Data members - // - - Base layout_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorVoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorVoltaTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorVoltaTensorOpMultiplicandCrosswise packed( - TensorCoord const &extent) { - return RowMajorVoltaTensorOpMultiplicandCrosswise(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return layout_.stride(); } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return layout_.stride(); } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -} // namespace layout -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h deleted file mode 100644 index 6ca60055e5555eac3c93cf8cd96938e6e2a92e56..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h +++ /dev/null @@ -1,1169 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/layout/pitch_linear.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace layout { - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -/// This one is the base class of all Ampere/Turing fp16/bf16/int8/int4/int1 -/// tensor core kernels. tf32 TN uses this too. -template -struct TensorOpMultiplicand { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Static constants - // - - /// This layout is optimized for 128b accesses - static int const kAccessSize = 128; - - static int const kElementSize = ElementSize; - static int const kElementsPerAccess = kAccessSize / kElementSize; - static int const kCrosswise = Crosswise; - - /// Contiguous dimension of the tile shape matches one shared memory cache - /// line - 128B. For 128bit access size, it equals to 8 accesses. - static int const kTileShapeContiguous = 128 / (kAccessSize / 8); - - /// Number of kblocks to store PartitionShape::kContiguous Elements - static int const kFactor = - kTileShapeContiguous * kElementsPerAccess / kCrosswise; - - static_assert( - (kFactor > 0), - "kCrosswise should be no large than one shared memory cache line."); - - /// The strided dimension needs to be at least (WarpSize(32) / - /// kTileShapeContiguous) for a warp to access. To ensure conflict free - /// access, it also needs to be at least (kTileShapeContiguous / kFactor). - /// See comments below - static int const kTileShapeStride = - ((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous)) - ? (kTileShapeContiguous / kFactor) - : (32 / kTileShapeContiguous); - - /// Fundamental tile shape in units of vectors to guarantee bank conflict free - /// shared memory load/store. - /// For kFactor = 1, TileShape = <8, 8> - /// For kFactor > 1, TileShape = <8, 4> - using TileShape = PitchLinearShape; - - /// Fundamental partition shape in units of vectors - using PartitionShape = PitchLinearShape<4, 4>; - - using PartitionCount = - PitchLinearShape; - - using AccessCount = - PitchLinearShape; - - private: - // - // Data members - // - - /// Stride data member. For GEMM, it equals to kCrosswise x stage. - Stride stride_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicand(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicand(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicand packed(TensorCoord const &extent) { - return TensorOpMultiplicand(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - // - // First, compute c and s of vector within source (in units of vector - // accesses) - // - - int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; - int vec_strided_idx = coord.strided() / kFactor; - - // Compute the fundamental tile being accessed - int tile_contiguous_idx = - vec_contiguous_idx / (TileShape::kContiguous / kFactor); - - int tile_contiguous_residual = - vec_contiguous_idx % (TileShape::kContiguous / kFactor) + - ((coord.strided() % kFactor) * (TileShape::kContiguous / kFactor)); - int tile_strided_residual = vec_strided_idx % TileShape::kStrided; - - // Compute the 'partition' within the fundamental tile - int partition_contiguous_idx = - tile_contiguous_residual / PartitionShape::kContiguous; - int partition_strided_idx = - tile_strided_residual / PartitionShape::kStrided; - - int partition_contiguous_residual = - tile_contiguous_residual % PartitionShape::kContiguous; - int partition_strided_residual = - tile_strided_residual % PartitionShape::kStrided; - - // - // Then swizzle - // - - int permuted_vec_contiguous_within_partition = - partition_contiguous_residual ^ (partition_strided_residual % 4); - - int permuted_partition_contiguous_within_tile = - partition_contiguous_idx ^ (partition_strided_idx % 2); - - // - // Compute final element location - // - - int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + - permuted_partition_contiguous_within_tile * - PartitionShape::kContiguous + - permuted_vec_contiguous_within_partition) * - kElementsPerAccess + - (coord.contiguous() % kElementsPerAccess); - - int element_strided = vec_strided_idx; - - return element_contiguous + element_strided * stride_[0] * kFactor; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -template -struct TensorOpMultiplicandCongruous { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicand; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - static int const kCrosswise = Base::kCrosswise; - static int const kFactor = Base::kFactor; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - - private: - // - // Data members - // - - Base layout_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous(Index ldm = 0) : layout_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous(Stride stride) : layout_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return TensorOpMultiplicandCongruous(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(coord); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return coord; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return layout_.stride(); } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return layout_.stride(); } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(extent); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -/// This one is just for TF32 NT kernel. -template -struct TensorOpMultiplicandCongruous<32, Crosswise> { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - /// This layout is optimized for 128b accesses - static int const kAccessSize = 128; - - /// Fundamental tile shape in units of vectors - using TileShape = PitchLinearShape<8, 4>; - - /// Partitionshape is the same as TileShape for this layout - using PartitionShape = PitchLinearShape<8, 4>; - - using PartitionCount = - PitchLinearShape; - - using AccessCount = - PitchLinearShape; - - // - // Static constants - // - static int const kElementSize = 32; - static int const kElementsPerAccess = kAccessSize / kElementSize; - static int const kCrosswise = Crosswise; - static int const kFactor = 1; - - private: - // - // Data members - // - - /// Stride data member. - Stride stride_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return TensorOpMultiplicandCongruous(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - int tc = coord.contiguous() / 32; - int ts = coord.strided() / 4; - - int c = (coord.contiguous() % 32) / kElementsPerAccess; - int s = coord.strided() % 4; - - LongIndex offset = (c ^ (2 * s)) * kElementsPerAccess + s * stride_[0] + - tc * 32 + ts * stride_[0] * 4 + coord.contiguous() % 4; - - return offset; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to -/// TensorOpMultiplicand -template -struct ColumnMajorTensorOpMultiplicandCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCongruous; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - static int const kCrosswise = Base::kCrosswise; - static int const kFactor = Base::kFactor; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return ColumnMajorTensorOpMultiplicandCongruous(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a row-major view of pitch-linear memory to -/// TensorOpMultiplicand -template -struct RowMajorTensorOpMultiplicandCongruous { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCongruous; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - static int const kCrosswise = Base::kCrosswise; - static int const kFactor = Base::kFactor; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { - return RowMajorTensorOpMultiplicandCongruous(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -template -struct TensorOpMultiplicandCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicand; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - static int const kCrosswise = Base::kCrosswise; - static int const kFactor = Base::kFactor; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - - private: - // - // Data members - // - - Base layout_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandCrosswise packed(TensorCoord const &extent) { - return TensorOpMultiplicandCrosswise(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(coord); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return coord; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return layout_.stride(); } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return layout_.stride(); } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(extent); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to -/// TensorOpMultiplicandCrosswise -template -struct ColumnMajorTensorOpMultiplicandCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCrosswise; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - - private: - // - // Data members - // - - Base layout_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorTensorOpMultiplicandCrosswise packed( - TensorCoord const &extent) { - return ColumnMajorTensorOpMultiplicandCrosswise(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return layout_.stride(); } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return layout_.stride(); } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a row-major view of pitch-linear memory to -/// TensorOpMultiplicandCrosswise -template -struct RowMajorTensorOpMultiplicandCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCrosswise; - - /// This layout is optimized for 128b accesses - static int const kAccessSize = Base::kAccessSize; - using TileShape = typename Base::TileShape; - using PartitionShape = typename Base::PartitionShape; - - // - // Static constants - // - - static int const kElementSize = Base::kElementSize; - static int const kElementsPerAccess = Base::kElementsPerAccess; - using PartitionCount = typename Base::PartitionCount; - using AccessCount = typename Base::AccessCount; - - private: - // - // Data members - // - - Base layout_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorTensorOpMultiplicandCrosswise packed( - TensorCoord const &extent) { - return RowMajorTensorOpMultiplicandCrosswise(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return layout_.stride(); } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return layout_.stride(); } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -template -struct TensorOpMultiplicandColumnMajorInterleaved { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - /// This layout is optimized for 128b accesses - static int const kAccessSize = 128; - - // - // Static constants - // - - static int const kElementSize = ElementSize; - static int const kElementsPerAccess = kAccessSize / kElementSize; - - //static int const kThreadBlockStrided = ThreadBlockStrided; - static int const kInterleavedK = InterleavedK; - -private: - - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandColumnMajorInterleaved(Index ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandColumnMajorInterleaved(Stride stride): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandColumnMajorInterleaved packed(TensorCoord const &extent) { - return TensorOpMultiplicandColumnMajorInterleaved(extent[0] * kInterleavedK); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - int const rows_per_smem_cache_line = 128 / kInterleavedK; - - int row_id = coord.strided() / rows_per_smem_cache_line; - int col_id = (coord.strided() % rows_per_smem_cache_line) * kInterleavedK + coord.contiguous(); - - int access_block_id = col_id >> 4; - int swizzle_access_block_id = access_block_id ^ (row_id & 1); - - int swizzle_col_id = swizzle_access_block_id << 4; - - return row_id * 128 + swizzle_col_id; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return (extent[1] / kInterleavedK) * stride_[0]; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -template -struct TensorOpMultiplicandRowMajorInterleaved { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - /// This layout is optimized for 128b accesses - static int const kAccessSize = 128; - - // - // Static constants - // - - static int const kElementSize = ElementSize; - static int const kElementsPerAccess = kAccessSize / kElementSize; - - //static int const kThreadBlockStrided = ThreadBlockStrided; - static int const kInterleavedK = InterleavedK; - -private: - - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandRowMajorInterleaved(Index ldm = 0): stride_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandRowMajorInterleaved(Stride stride): stride_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandRowMajorInterleaved packed(TensorCoord const &extent) { - return TensorOpMultiplicandRowMajorInterleaved(extent[1] * kInterleavedK); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - int const rows_per_smem_cache_line = 128 / kInterleavedK; - - int row_id = coord.strided() / rows_per_smem_cache_line; - int col_id = (coord.strided() % rows_per_smem_cache_line) * kInterleavedK + coord.contiguous(); - - int access_block_id = col_id >> 4; - int swizzle_access_block_id = access_block_id ^ (row_id & 1); - - int swizzle_col_id = swizzle_access_block_id << 4; - - return row_id * 128 + swizzle_col_id; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return (extent[0] / kInterleavedK) * stride_[0]; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h deleted file mode 100644 index e3104906ee1b1d22df7f8d2822e67fd14cf4e56b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h +++ /dev/null @@ -1,1139 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief layouts needed by Ampere fp64 tensor core kernels. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace layout { - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -struct TensorOpMultiplicandCongruous64b { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Static constants - // - - static int const kElementSize = 64; - static int const kElementsPerAccess = 1; - - private: - - // - // Data members - // - - /// Stride data member. - Stride stride_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous64b(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous64b(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { - return TensorOpMultiplicandCongruous64b(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - int tc = coord.contiguous() / 16; - int ts = coord.strided() / 4; - - int c = coord.contiguous() % 16; - int s = coord.strided() % 4; - - - int bank = ((((c & 1) * 4 + (c & 6) / 2)) ^ (s & 1)) * 2 + (c / 8); - int row = (c & 6) / 2; - - bank ^= ((s & 2) * 2); - - LongIndex offset = tc * 16 + bank + (ts * 4 + row) * stride_[0]; - - return offset; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } - - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - return TensorCoord(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to -/// TensorOpMultiplicand -struct ColumnMajorTensorOpMultiplicandCongruous64b { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCongruous64b; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { - return ColumnMajorTensorOpMultiplicandCongruous64b(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a row-major view of pitch-linear memory to -/// TensorOpMultiplicand -struct RowMajorTensorOpMultiplicandCongruous64b { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCongruous64b; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { - return RowMajorTensorOpMultiplicandCongruous64b(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -struct TensorOpMultiplicand64bCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Static constants - // - - static int const kElementSize = 64; - static int const kElementsPerAccess = 1; - - private: - - // - // Data members - // - - /// Stride data member. - Stride stride_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicand64bCrosswise(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicand64bCrosswise(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { - return TensorOpMultiplicand64bCrosswise(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - int tc = coord.contiguous() / 16; - int ts = coord.strided() / 16; - - int c = coord.contiguous() % 16; - int s = coord.strided() % 16; - - int k_group = c / 4; - int access_s = s / 2; - - int row = access_s % 4; - int bank = ((k_group & 2) << 2) ^ ((s % 2) << 3) + (c % 4) * 2 + (access_s / 4) ^ (k_group & 1); - - int smem_row = (k_group * 4 + row) + tc * 16; - int smem_col = ts * 16 + bank; - - LongIndex offset = smem_row * stride_[0] + smem_col; - - return offset; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -struct ColumnMajorTensorOpMultiplicand64bCrosswise { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicand64bCrosswise; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { - return ColumnMajorTensorOpMultiplicand64bCrosswise(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -struct RowMajorTensorOpMultiplicand64bCrosswise { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicand64bCrosswise; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { - return RowMajorTensorOpMultiplicand64bCrosswise(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -struct TensorOpMultiplicandCongruous128b { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Static constants - // - - static int const kElementSize = 128; - static int const kElementsPerAccess = 1; - - private: - - // - // Data members - // - - /// Stride data member. - Stride stride_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous128b(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCongruous128b(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { - return TensorOpMultiplicandCongruous128b(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - Index tc = coord.contiguous() / 8; - Index ts = coord.strided() / 4; - - Index c = coord.contiguous() % 8; - Index s = coord.strided() % 4; - - Index k_index = (c / 2); - - Index bank = (((c & 1) * 4) | (s ^ k_index)); - - LongIndex offset = tc * 8 + bank + (ts * 4 + k_index) * stride_[0]; - - return offset; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - return TensorCoord(); - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to -/// TensorOpMultiplicand -struct ColumnMajorTensorOpMultiplicandCongruous128b { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCongruous128b; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { - return ColumnMajorTensorOpMultiplicandCongruous128b(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.contiguous(), coord.strided()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a row-major view of pitch-linear memory to -/// TensorOpMultiplicand -struct RowMajorTensorOpMultiplicandCongruous128b { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCongruous128b; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { - return RowMajorTensorOpMultiplicandCongruous128b(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Inverse of layout function, mapping linear offset to logical coordinate - CUTLASS_HOST_DEVICE - TensorCoord inverse(LongIndex offset) const { - PitchLinearCoord coord = layout_.inverse(offset); - return MatrixCoord(coord.strided(), coord.contiguous()); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template based on element size (in bits) - defined in terms of pitch-linear -/// memory and Crosswise size (in elements). -struct TensorOpMultiplicandCrosswise128x4 { - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = PitchLinearCoord; - - /// Stride vector - using Stride = Coord; - - // - // Static constants - // - - static int const kElementSize = 128; - static int const kElementsPerAccess = 1; - - private: - - // - // Data members - // - - /// Stride data member. - Stride stride_; - - public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCrosswise128x4(Index ldm = 0) : stride_(ldm) {} - - /// Ctor - CUTLASS_HOST_DEVICE - TensorOpMultiplicandCrosswise128x4(Stride stride) : stride_(stride) {} - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static TensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { - return TensorOpMultiplicandCrosswise128x4(extent[0]); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - - Index tc = coord.contiguous() / 8; - Index ts = coord.strided() / 8; - - Index c = coord.contiguous() % 8; - Index s = coord.strided() % 8; - - Index liq = c % 4; - - Index bank = liq + ((s & 1) * 4) ^ (c & 4); - - Index k_index = (c & 4) + (s / 4) * 2 + ((s & 2) / 2); - - LongIndex offset = (tc * 8 + k_index) * stride_[0] + ts * 8 + bank; - - return offset; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Compute the number of contiguous elements needed to store a tensor with - /// the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return extent[1] * stride_[0]; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a column-major view of pitch-linear memory to -/// TensorOpMultiplicand -struct ColumnMajorTensorOpMultiplicandCrosswise128x4 { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCrosswise128x4; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - ColumnMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static ColumnMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { - return ColumnMajorTensorOpMultiplicandCrosswise128x4(extent.column()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.row(), coord.column())); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Template mapping a row-major view of pitch-linear memory to -/// TensorOpMultiplicand -struct RowMajorTensorOpMultiplicandCrosswise128x4 { - - /// Logical rank of tensor - static int const kRank = 2; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = MatrixCoord; - - /// Stride vector - using Stride = Coord; - - // - // Invariants - // - - using Base = TensorOpMultiplicandCrosswise128x4; - -private: - - // - // Data members - // - - Base layout_; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } - - /// Ctor - CUTLASS_HOST_DEVICE - RowMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static RowMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { - return RowMajorTensorOpMultiplicandCrosswise128x4(extent.row()); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (contiguous, strided) - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return layout_(PitchLinearCoord(coord.column(), coord.row())); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &extent) const { - return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace layout -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/vector.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/vector.h deleted file mode 100644 index 6cb74f35ffa1ac56a4c0c9c07e888b414d1be3a1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/layout/vector.h +++ /dev/null @@ -1,105 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used for rank=1 vectors. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" - -namespace cutlass { -namespace layout { - -/// Tensor layout for densely packed vectors. -class PackedVectorLayout { -public: - /// Logical rank of tensor - static int const kRank = 1; - - /// Rank of stride vector - static int const kStrideRank = 1; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = Coord; - - /// Stride vector - using Stride = Coord; - -private: - - // - // No actual stride vector stored - // - -public: - - // - // Methods - // - - CUTLASS_HOST_DEVICE - PackedVectorLayout() { } - - /// Helper returns a layout to a tightly packed tensor - CUTLASS_HOST_DEVICE - static PackedVectorLayout packed(TensorCoord const &size) { - CUTLASS_UNUSED(size); - return PackedVectorLayout(); - } - - /// Returns the offset of a coordinate in linear memory - CUTLASS_HOST_DEVICE - LongIndex operator()(TensorCoord const &coord) const { - return coord[0]; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return make_Coord(1); - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &size) const { - return size[0]; - } -}; - -} // namespace layout -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix.h deleted file mode 100644 index 00222c128dc1216d541e7dd7341d71138cfa28a0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix.h +++ /dev/null @@ -1,14129 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - \file - \brief Matrix classes with value semantics. -*/ - -#pragma once - -#if !defined(__CUDACC_RTC__) -#include -#include -#endif - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/matrix.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Primary template with partial specializations to follow -template struct Matrix; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 1-by-2 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 1; - - /// Number of columns in matrix - static int const kColumns = 2; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 2; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 1-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 1-by-2 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1 - ) { - - data[0] = _0_0; data[1] = _0_1; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[1] = data[1]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 1 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 1 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x2(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x2(v, i, 0); - } - - /// Forms a 1-by-2 matrix by horizontally concatenating an Element with an Element - CUTLASS_HOST_DEVICE - static Matrix hcat(Element lhs, Element rhs) { - return Matrix( - lhs, rhs); - } - - /// Concatenates this matrix with a an Element to form a 1-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Element rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 1-by-2 matrix to form a 1-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 1-by-2 matrix to form a 2-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-2 matrix to form a 3-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 3-by-2 matrix to form a 4-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Elementwise add operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - - return result; - } - - /// Elementwise add operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - - return *this; - } - - /// Elementwise subtract operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - - return result; - } - - /// Elementwise subtract operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - - return *this; - } - - /// Elementwise multiply operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - - return result; - } - - /// Scalar multiply operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - - return result; - } - - /// Scalar multiply operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - - return *this; - } - - /// Elementwise divide operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - - return result; - } - - /// Scalar divide operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - - return result; - } - - /// Scalar divide operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - - return *this; - } - - /// Elementwise divide operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (1-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - - return m; - } - - /// Matrix product of size 1-by-1-by-2 - CUTLASS_HOST_DEVICE - Element product(Matrix const &rhs, Element accum = Element()) const { - - // k=0 - accum += data[0] * rhs.data[0]; - - // k=1 - accum += data[1] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 1-by-1-by-2 - CUTLASS_HOST_DEVICE - Element operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 1-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 1-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 1-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 1-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Dot product of vectors with extent 2 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - return accum; - } - - /// Dot product of vectors with extent 2 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - return accum; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - - return accum; - } - -}; - -/// Template alias for 1-by-2 matrix -template -using Matrix1x2 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix1x2 make_Matrix1x2( - Element _0_0, Element _0_1 -) { - return Matrix1x2( - _0_0, _0_1 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 1-by-3 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 1; - - /// Number of columns in matrix - static int const kColumns = 3; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 3; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 1-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 1-by-3 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[1] = data[1]; - mt.data[2] = data[2]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 1 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 1 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x3(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x3(v, i, 0); - } - - /// Forms a 1-by-3 matrix by horizontally concatenating an Element with a 1-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Element lhs, Matrix const & rhs) { - return Matrix( - lhs, rhs.at(0, 0), rhs.at(0, 1)); - } - - /// Forms a 1-by-3 matrix by horizontally concatenating a 1-by-2 matrix with an Element - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Element rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs); - } - - /// Concatenates this matrix with a an Element to form a 1-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Element rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 1-by-3 matrix to form a 2-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-3 matrix to form a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 3-by-3 matrix to form a 4-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Elementwise add operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - - return result; - } - - /// Elementwise add operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - - return *this; - } - - /// Elementwise subtract operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - - return result; - } - - /// Elementwise subtract operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - - return *this; - } - - /// Elementwise multiply operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - - return result; - } - - /// Scalar multiply operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - - return result; - } - - /// Scalar multiply operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - - return *this; - } - - /// Elementwise divide operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - - return result; - } - - /// Scalar divide operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - - return result; - } - - /// Scalar divide operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - - return *this; - } - - /// Elementwise divide operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (1-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - - return m; - } - - /// Matrix product of size 1-by-1-by-3 - CUTLASS_HOST_DEVICE - Element product(Matrix const &rhs, Element accum = Element()) const { - - // k=0 - accum += data[0] * rhs.data[0]; - - // k=1 - accum += data[1] * rhs.data[1]; - - // k=2 - accum += data[2] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 1-by-1-by-3 - CUTLASS_HOST_DEVICE - Element operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 1-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - - return accum; - } - - /// Matrix product of size 1-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 1-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 1-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Dot product of vectors with extent 3 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - return accum; - } - - /// Dot product of vectors with extent 3 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - return accum; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - - return accum; - } - - /// Cross product - CUTLASS_HOST_DEVICE - Matrix cross(Matrix const &rhs) const { - return Matrix( - data[1] * rhs.data[2] - data[2] * rhs.data[1], - data[2] * rhs.data[0] - data[0] * rhs.data[2], - data[0] * rhs.data[1] - data[1] * rhs.data[0] - ); - } - -}; - -/// Template alias for 1-by-3 matrix -template -using Matrix1x3 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix1x3 make_Matrix1x3( - Element _0_0, Element _0_1, Element _0_2 -) { - return Matrix1x3( - _0_0, _0_1, _0_2 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 1-by-4 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 1; - - /// Number of columns in matrix - static int const kColumns = 4; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 4; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 1-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 1-by-4 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, Element _0_3 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[1] = data[1]; - mt.data[2] = data[2]; - mt.data[3] = data[3]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 1 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 1 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x4(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x4(v, i, 0); - } - - /// Forms a 1-by-4 matrix by horizontally concatenating an Element with a 1-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Element lhs, Matrix const & rhs) { - return Matrix( - lhs, rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)); - } - - /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-2 matrix with a 1-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)); - } - - /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-3 matrix with an Element - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Element rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs); - } - - /// Concatenates this matrix with a a 1-by-4 matrix to form a 2-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-4 matrix to form a 3-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 3-by-4 matrix to form a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Elementwise add operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - return result; - } - - /// Elementwise add operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - return *this; - } - - /// Elementwise subtract operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - return result; - } - - /// Elementwise subtract operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - return *this; - } - - /// Elementwise multiply operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - return result; - } - - /// Scalar multiply operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - return result; - } - - /// Scalar multiply operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - data[3] *= s; - - return *this; - } - - /// Elementwise divide operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - return result; - } - - /// Scalar divide operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - return result; - } - - /// Scalar divide operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - data[3] /= s; - - return *this; - } - - /// Elementwise divide operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (1-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - - return m; - } - - /// Matrix product of size 1-by-1-by-4 - CUTLASS_HOST_DEVICE - Element product(Matrix const &rhs, Element accum = Element()) const { - - // k=0 - accum += data[0] * rhs.data[0]; - - // k=1 - accum += data[1] * rhs.data[1]; - - // k=2 - accum += data[2] * rhs.data[2]; - - // k=3 - accum += data[3] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 1-by-1-by-4 - CUTLASS_HOST_DEVICE - Element operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - - // k=3 - accum.data[0] += data[3] * rhs.data[6]; - accum.data[1] += data[3] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 1-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - - // k=3 - accum.data[0] += data[3] * rhs.data[9]; - accum.data[1] += data[3] * rhs.data[10]; - accum.data[2] += data[3] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 1-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - - // k=3 - accum.data[0] += data[3] * rhs.data[12]; - accum.data[1] += data[3] * rhs.data[13]; - accum.data[2] += data[3] * rhs.data[14]; - accum.data[3] += data[3] * rhs.data[15]; - - return accum; - } - - /// Matrix product of size 1-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 1-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Dot product of vectors with extent 4 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - accum += data[3] * rhs.data[3]; - return accum; - } - - /// Dot product of vectors with extent 4 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - accum += data[3] * rhs.data[3]; - return accum; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - - return accum; - } - -}; - -/// Template alias for 1-by-4 matrix -template -using Matrix1x4 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix1x4 make_Matrix1x4( - Element _0_0, Element _0_1, Element _0_2, Element _0_3 -) { - return Matrix1x4( - _0_0, _0_1, _0_2, _0_3 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 2-by-1 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 2; - - /// Number of columns in matrix - static int const kColumns = 1; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 2; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 2-by-1 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 2-by-1 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, - Element _1_0 - ) { - - data[0] = _0_0; - data[1] = _1_0; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[1] = data[1]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 1 + j + 0]; - m.data[1] = data[i * 1 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 1 + j + 0] = m.data[0]; - data[i * 1 + j + 1] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_2x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_2x1(v, 0, j); - } - - /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-3 matrix to form a 2-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 2-by-1 matrix by vertically concatenating an Element with an Element - CUTLASS_HOST_DEVICE - static Matrix vcat(Element upper, Element lower) { - return Matrix( - upper - , lower); - } - - /// Concatenates this matrix with a an Element to form a 3-by-1 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Element rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-1 matrix to form a 4-by-1 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Elementwise add operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - - result.data[1] = data[1] + rhs.data[1]; - - return result; - } - - /// Elementwise add operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - - data[1] += rhs.data[1]; - - return *this; - } - - /// Elementwise subtract operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - - result.data[1] = data[1] - rhs.data[1]; - - return result; - } - - /// Elementwise subtract operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - - data[1] -= rhs.data[1]; - - return *this; - } - - /// Elementwise multiply operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - - result.data[1] = data[1] * rhs.data[1]; - - return result; - } - - /// Scalar multiply operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - - result.data[1] = data[1] * s; - - return result; - } - - /// Scalar multiply operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - - data[1] *= s; - - return *this; - } - - /// Elementwise divide operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - - result.data[1] = data[1] / rhs.data[1]; - - return result; - } - - /// Scalar divide operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - - result.data[1] = data[1] / s; - - return result; - } - - /// Scalar divide operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - - data[1] /= s; - - return *this; - } - - /// Elementwise divide operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (2-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - - data[1] /= rhs.data[1]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - - return m; - } - - /// Matrix product of size 2-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[1] * rhs.data[0]; - - return accum; - } - - /// Matrix product of size 2-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 2-by-2-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[1] * rhs.data[0]; - accum.data[3] += data[1] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 2-by-2-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-3-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[1] * rhs.data[0]; - accum.data[4] += data[1] * rhs.data[1]; - accum.data[5] += data[1] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 2-by-3-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-4-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[1] * rhs.data[0]; - accum.data[5] += data[1] * rhs.data[1]; - accum.data[6] += data[1] * rhs.data[2]; - accum.data[7] += data[1] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 2-by-4-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Dot product of vectors with extent 2 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - return accum; - } - - /// Dot product of vectors with extent 2 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - return accum; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - - return accum; - } - -}; - -/// Template alias for 2-by-1 matrix -template -using Matrix2x1 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix2x1 make_Matrix2x1( - Element _0_0, - Element _1_0 -) { - return Matrix2x1( - _0_0, - _1_0 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 2-by-2 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 2; - - /// Number of columns in matrix - static int const kColumns = 2; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 4; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 2-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 2-by-2 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, - Element _1_0, Element _1_1 - ) { - - data[0] = _0_0; data[1] = _0_1; - data[2] = _1_0; data[3] = _1_1; - } - - /// Constructs a 2-by-2 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_1.data[0]; - data[3] = row_1.data[1]; - } - - /// Static method to construct a 2-by-2 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_0.data[1]; - result.data[3] = column_1.data[1]; - return result; - } - - /// Constructs an identity matrix - CUTLASS_HOST_DEVICE - static Matrix identity() { - Matrix m; - - m.data[0] = Element(1); - m.data[3] = Element(1); - - return m; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[3] = diag.data[1]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[3] = diag.data[1]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[3]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[2] = data[1]; - mt.data[1] = data[2]; - mt.data[3] = data[3]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x2(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x2(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 2] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_2x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_2x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - m.data[2] = data[i * 2 + j + 2]; - m.data[3] = data[i * 2 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - data[i * 2 + j + 2] = m.data[2]; - data[i * 2 + j + 3] = m.data[3]; - - return *this; - } - - /// Forms a 2-by-2 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0) - , lhs.at(1, 0), rhs.at(1, 0)); - } - - /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 2-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 1-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1) - , lower.at(0, 0), lower.at(0, 1)); - } - - /// Concatenates this matrix with a a 1-by-2 matrix to form a 3-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-2 matrix to form a 4-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Forms a 2-by-2 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Element B, - Element C, Element D) { - return Matrix( - A, B - , C, D - ); - } - - /// Elementwise add operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - return result; - } - - /// Elementwise add operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - return *this; - } - - /// Elementwise subtract operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - return result; - } - - /// Elementwise subtract operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - return *this; - } - - /// Elementwise multiply operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - return result; - } - - /// Scalar multiply operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - return result; - } - - /// Scalar multiply operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - - data[2] *= s; - data[3] *= s; - - return *this; - } - - /// Elementwise divide operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - return result; - } - - /// Scalar divide operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - return result; - } - - /// Scalar divide operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - - data[2] /= s; - data[3] /= s; - - return *this; - } - - /// Elementwise divide operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (2-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - - return m; - } - - /// Matrix product of size 2-by-1-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[2] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[3] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 2-by-1-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[2] * rhs.data[0]; - accum.data[3] += data[2] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[3] * rhs.data[2]; - accum.data[3] += data[3] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 2-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 2-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[2] * rhs.data[0]; - accum.data[4] += data[2] * rhs.data[1]; - accum.data[5] += data[2] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[3] * rhs.data[3]; - accum.data[4] += data[3] * rhs.data[4]; - accum.data[5] += data[3] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 2-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[2] * rhs.data[0]; - accum.data[5] += data[2] * rhs.data[1]; - accum.data[6] += data[2] * rhs.data[2]; - accum.data[7] += data[2] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[3] * rhs.data[4]; - accum.data[5] += data[3] * rhs.data[5]; - accum.data[6] += data[3] * rhs.data[6]; - accum.data[7] += data[3] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 2-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[3]; - - return accum; - } - - /// Returns 2-by-2 rotation matrix - CUTLASS_HOST_DEVICE - static Matrix rotation(Element theta) { - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - return Matrix( - c, -s, - s, c - ); - } - - /// Computes the determinant of a 2-by-2 matrix - CUTLASS_HOST_DEVICE - Element determinant(Element accum = Element()) const { - accum += data[0] * data[3] - data[1] * data[2]; - - return accum; - } - - /// Computes the inverse of a 2-by-2 matrix given - /// the matrix's determinant - CUTLASS_HOST_DEVICE - Matrix inverse(Element det) const { - return Matrix( - data[3], -data[1], - -data[2], data[0] - ) * (Element(1) / det); - } - - /// Computes the inverse of a 2-by-2 matrix. - CUTLASS_HOST_DEVICE - Matrix inverse() const { - return inverse(determinant()); - } - -}; - -/// Template alias for 2-by-2 matrix -template -using Matrix2x2 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix2x2 make_Matrix2x2( - Element _0_0, Element _0_1, - Element _1_0, Element _1_1 -) { - return Matrix2x2( - _0_0, _0_1, - _1_0, _1_1 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 2-by-3 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 2; - - /// Number of columns in matrix - static int const kColumns = 3; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 6; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 2-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 2-by-3 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, - Element _1_0, Element _1_1, Element _1_2 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; - data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; - } - - /// Constructs a 2-by-3 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_0.data[2]; - data[3] = row_1.data[0]; - data[4] = row_1.data[1]; - data[5] = row_1.data[2]; - } - - /// Static method to construct a 2-by-3 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1, - Matrix const &column_2 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_2.data[0]; - result.data[3] = column_0.data[1]; - result.data[4] = column_1.data[1]; - result.data[5] = column_2.data[1]; - return result; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[3] = diag.data[1]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[3] = diag.data[1]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[3]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[2] = data[1]; - mt.data[4] = data[2]; - mt.data[1] = data[3]; - mt.data[3] = data[4]; - mt.data[5] = data[5]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x3(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x3(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 3] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_2x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_2x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 3]; - m.data[3] = data[i * 3 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 3] = m.data[2]; - data[i * 3 + j + 4] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - m.data[3] = data[i * 3 + j + 3]; - m.data[4] = data[i * 3 + j + 4]; - m.data[5] = data[i * 3 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - data[i * 3 + j + 3] = m.data[3]; - data[i * 3 + j + 4] = m.data[4]; - data[i * 3 + j + 5] = m.data[5]; - - return *this; - } - - /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) - , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1)); - } - - /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) - , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0)); - } - - /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 2-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 1-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); - } - - /// Concatenates this matrix with a a 1-by-3 matrix to form a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-3 matrix to form a 4-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Forms a 2-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Matrix const & B, - Element C, Matrix const & D) { - return Matrix( - A, B.at(0, 0), B.at(0, 1) - , C, D.at(0, 0), D.at(0, 1) - ); - } - - /// Forms a 2-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Element B, - Matrix const & C, Element D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B - , C.at(0, 0), C.at(0, 1), D - ); - } - - /// Elementwise add operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - - result.data[3] = data[3] + rhs.data[3]; - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - - return result; - } - - /// Elementwise add operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - - data[3] += rhs.data[3]; - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - - return *this; - } - - /// Elementwise subtract operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - - result.data[3] = data[3] - rhs.data[3]; - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - - return result; - } - - /// Elementwise subtract operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - - data[3] -= rhs.data[3]; - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - - return *this; - } - - /// Elementwise multiply operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - - result.data[3] = data[3] * rhs.data[3]; - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - - return result; - } - - /// Scalar multiply operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - - result.data[3] = data[3] * s; - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - - return result; - } - - /// Scalar multiply operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - - data[3] *= s; - data[4] *= s; - data[5] *= s; - - return *this; - } - - /// Elementwise divide operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - - result.data[3] = data[3] / rhs.data[3]; - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - - return result; - } - - /// Scalar divide operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - - result.data[3] = data[3] / s; - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - - return result; - } - - /// Scalar divide operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - - data[3] /= s; - data[4] /= s; - data[5] /= s; - - return *this; - } - - /// Elementwise divide operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (2-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - - data[3] /= rhs.data[3]; - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - - return m; - } - - /// Matrix product of size 2-by-1-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[3] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[4] * rhs.data[1]; - - // k=2 - accum.data[0] += data[2] * rhs.data[2]; - accum.data[1] += data[5] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 2-by-1-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[3] * rhs.data[0]; - accum.data[3] += data[3] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[4] * rhs.data[2]; - accum.data[3] += data[4] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - accum.data[2] += data[5] * rhs.data[4]; - accum.data[3] += data[5] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 2-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[3] * rhs.data[0]; - accum.data[4] += data[3] * rhs.data[1]; - accum.data[5] += data[3] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[4] * rhs.data[3]; - accum.data[4] += data[4] * rhs.data[4]; - accum.data[5] += data[4] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - accum.data[3] += data[5] * rhs.data[6]; - accum.data[4] += data[5] * rhs.data[7]; - accum.data[5] += data[5] * rhs.data[8]; - - return accum; - } - - /// Matrix product of size 2-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 2-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[3] * rhs.data[0]; - accum.data[5] += data[3] * rhs.data[1]; - accum.data[6] += data[3] * rhs.data[2]; - accum.data[7] += data[3] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[4] * rhs.data[4]; - accum.data[5] += data[4] * rhs.data[5]; - accum.data[6] += data[4] * rhs.data[6]; - accum.data[7] += data[4] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - accum.data[4] += data[5] * rhs.data[8]; - accum.data[5] += data[5] * rhs.data[9]; - accum.data[6] += data[5] * rhs.data[10]; - accum.data[7] += data[5] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 2-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[4]; - - return accum; - } - -}; - -/// Template alias for 2-by-3 matrix -template -using Matrix2x3 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix2x3 make_Matrix2x3( - Element _0_0, Element _0_1, Element _0_2, - Element _1_0, Element _1_1, Element _1_2 -) { - return Matrix2x3( - _0_0, _0_1, _0_2, - _1_0, _1_1, _1_2 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 2-by-4 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 2; - - /// Number of columns in matrix - static int const kColumns = 4; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 8; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 2-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 2-by-4 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, Element _0_3, - Element _1_0, Element _1_1, Element _1_2, Element _1_3 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; - data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; - } - - /// Constructs a 2-by-4 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_0.data[2]; - data[3] = row_0.data[3]; - data[4] = row_1.data[0]; - data[5] = row_1.data[1]; - data[6] = row_1.data[2]; - data[7] = row_1.data[3]; - } - - /// Static method to construct a 2-by-4 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1, - Matrix const &column_2, - Matrix const &column_3 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_2.data[0]; - result.data[3] = column_3.data[0]; - result.data[4] = column_0.data[1]; - result.data[5] = column_1.data[1]; - result.data[6] = column_2.data[1]; - result.data[7] = column_3.data[1]; - return result; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - m.data[6] = s; - m.data[7] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[3] = diag.data[1]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[3] = diag.data[1]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[3]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[2] = data[1]; - mt.data[4] = data[2]; - mt.data[6] = data[3]; - mt.data[1] = data[4]; - mt.data[3] = data[5]; - mt.data[5] = data[6]; - mt.data[7] = data[7]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 2 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x4(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x4(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 4] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_2x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_2x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 4]; - m.data[3] = data[i * 4 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 4] = m.data[2]; - data[i * 4 + j + 5] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 4]; - m.data[4] = data[i * 4 + j + 5]; - m.data[5] = data[i * 4 + j + 6]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 4] = m.data[3]; - data[i * 4 + j + 5] = m.data[4]; - data[i * 4 + j + 6] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - m.data[4] = data[i * 4 + j + 4]; - m.data[5] = data[i * 4 + j + 5]; - m.data[6] = data[i * 4 + j + 6]; - m.data[7] = data[i * 4 + j + 7]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - data[i * 4 + j + 4] = m.data[4]; - data[i * 4 + j + 5] = m.data[5]; - data[i * 4 + j + 6] = m.data[6]; - data[i * 4 + j + 7] = m.data[7]; - - return *this; - } - - /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) - , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2)); - } - - /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) - , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1)); - } - - /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-3 matrix with a 2-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) - , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0)); - } - - /// Forms a 2-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 1-by-4 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); - } - - /// Concatenates this matrix with a a 1-by-4 matrix to form a 3-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Concatenates this matrix with a a 2-by-4 matrix to form a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Forms a 2-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Matrix const & B, - Element C, Matrix const & D) { - return Matrix( - A, B.at(0, 0), B.at(0, 1), B.at(0, 2) - , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) - ); - } - - /// Forms a 2-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) - , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) - ); - } - - /// Forms a 2-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Element B, - Matrix const & C, Element D) { - return Matrix( - A.at(0, 0), A.at(0, 1), A.at(0, 2), B - , C.at(0, 0), C.at(0, 1), C.at(0, 2), D - ); - } - - /// Elementwise add operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - result.data[6] = data[6] + rhs.data[6]; - result.data[7] = data[7] + rhs.data[7]; - - return result; - } - - /// Elementwise add operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - data[6] += rhs.data[6]; - data[7] += rhs.data[7]; - - return *this; - } - - /// Elementwise subtract operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - result.data[6] = data[6] - rhs.data[6]; - result.data[7] = data[7] - rhs.data[7]; - - return result; - } - - /// Elementwise subtract operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - data[6] -= rhs.data[6]; - data[7] -= rhs.data[7]; - - return *this; - } - - /// Elementwise multiply operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - result.data[6] = data[6] * rhs.data[6]; - result.data[7] = data[7] * rhs.data[7]; - - return result; - } - - /// Scalar multiply operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - result.data[6] = data[6] * s; - result.data[7] = data[7] * s; - - return result; - } - - /// Scalar multiply operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - data[3] *= s; - - data[4] *= s; - data[5] *= s; - data[6] *= s; - data[7] *= s; - - return *this; - } - - /// Elementwise divide operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - result.data[6] = data[6] / rhs.data[6]; - result.data[7] = data[7] / rhs.data[7]; - - return result; - } - - /// Scalar divide operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - result.data[6] = data[6] / s; - result.data[7] = data[7] / s; - - return result; - } - - /// Scalar divide operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - data[3] /= s; - - data[4] /= s; - data[5] /= s; - data[6] /= s; - data[7] /= s; - - return *this; - } - - /// Elementwise divide operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (2-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - data[6] /= rhs.data[6]; - data[7] /= rhs.data[7]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - m.data[6] = -data[6]; - m.data[7] = -data[7]; - - return m; - } - - /// Matrix product of size 2-by-1-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[4] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[5] * rhs.data[1]; - - // k=2 - accum.data[0] += data[2] * rhs.data[2]; - accum.data[1] += data[6] * rhs.data[2]; - - // k=3 - accum.data[0] += data[3] * rhs.data[3]; - accum.data[1] += data[7] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 2-by-1-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[4] * rhs.data[0]; - accum.data[3] += data[4] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[5] * rhs.data[2]; - accum.data[3] += data[5] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - accum.data[2] += data[6] * rhs.data[4]; - accum.data[3] += data[6] * rhs.data[5]; - - // k=3 - accum.data[0] += data[3] * rhs.data[6]; - accum.data[1] += data[3] * rhs.data[7]; - accum.data[2] += data[7] * rhs.data[6]; - accum.data[3] += data[7] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 2-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[4] * rhs.data[0]; - accum.data[4] += data[4] * rhs.data[1]; - accum.data[5] += data[4] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[5] * rhs.data[3]; - accum.data[4] += data[5] * rhs.data[4]; - accum.data[5] += data[5] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - accum.data[3] += data[6] * rhs.data[6]; - accum.data[4] += data[6] * rhs.data[7]; - accum.data[5] += data[6] * rhs.data[8]; - - // k=3 - accum.data[0] += data[3] * rhs.data[9]; - accum.data[1] += data[3] * rhs.data[10]; - accum.data[2] += data[3] * rhs.data[11]; - accum.data[3] += data[7] * rhs.data[9]; - accum.data[4] += data[7] * rhs.data[10]; - accum.data[5] += data[7] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 2-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[4] * rhs.data[0]; - accum.data[5] += data[4] * rhs.data[1]; - accum.data[6] += data[4] * rhs.data[2]; - accum.data[7] += data[4] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[5] * rhs.data[4]; - accum.data[5] += data[5] * rhs.data[5]; - accum.data[6] += data[5] * rhs.data[6]; - accum.data[7] += data[5] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - accum.data[4] += data[6] * rhs.data[8]; - accum.data[5] += data[6] * rhs.data[9]; - accum.data[6] += data[6] * rhs.data[10]; - accum.data[7] += data[6] * rhs.data[11]; - - // k=3 - accum.data[0] += data[3] * rhs.data[12]; - accum.data[1] += data[3] * rhs.data[13]; - accum.data[2] += data[3] * rhs.data[14]; - accum.data[3] += data[3] * rhs.data[15]; - accum.data[4] += data[7] * rhs.data[12]; - accum.data[5] += data[7] * rhs.data[13]; - accum.data[6] += data[7] * rhs.data[14]; - accum.data[7] += data[7] * rhs.data[15]; - - return accum; - } - - /// Matrix product of size 2-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 2-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - accum += data[6]; - accum += data[7]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - accum += data[6] * data[6]; - accum += data[7] * data[7]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[5]; - - return accum; - } - -}; - -/// Template alias for 2-by-4 matrix -template -using Matrix2x4 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix2x4 make_Matrix2x4( - Element _0_0, Element _0_1, Element _0_2, Element _0_3, - Element _1_0, Element _1_1, Element _1_2, Element _1_3 -) { - return Matrix2x4( - _0_0, _0_1, _0_2, _0_3, - _1_0, _1_1, _1_2, _1_3 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 3-by-1 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 3; - - /// Number of columns in matrix - static int const kColumns = 1; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 3; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 3-by-1 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 3-by-1 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, - Element _1_0, - Element _2_0 - ) { - - data[0] = _0_0; - data[1] = _1_0; - data[2] = _2_0; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[1] = data[1]; - mt.data[2] = data[2]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 1 + j + 0]; - m.data[1] = data[i * 1 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 1 + j + 0] = m.data[0]; - data[i * 1 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 1 + j + 0]; - m.data[1] = data[i * 1 + j + 1]; - m.data[2] = data[i * 1 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 1 + j + 0] = m.data[0]; - data[i * 1 + j + 1] = m.data[1]; - data[i * 1 + j + 2] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_3x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_3x1(v, 0, j); - } - - /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 3-by-3 matrix to form a 3-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 3-by-1 matrix by vertically concatenating an Element with a 2-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Element upper, Matrix const & lower) { - return Matrix( - upper - , lower.at(0, 0) - , lower.at(1, 0)); - } - - /// Forms a 3-by-1 matrix by vertically concatenating a 2-by-1 matrix with an Element - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Element lower) { - return Matrix( - upper.at(0, 0) - , upper.at(1, 0) - , lower); - } - - /// Concatenates this matrix with a an Element to form a 4-by-1 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Element rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Elementwise add operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - - result.data[1] = data[1] + rhs.data[1]; - - result.data[2] = data[2] + rhs.data[2]; - - return result; - } - - /// Elementwise add operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - - data[1] += rhs.data[1]; - - data[2] += rhs.data[2]; - - return *this; - } - - /// Elementwise subtract operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - - result.data[1] = data[1] - rhs.data[1]; - - result.data[2] = data[2] - rhs.data[2]; - - return result; - } - - /// Elementwise subtract operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - - data[1] -= rhs.data[1]; - - data[2] -= rhs.data[2]; - - return *this; - } - - /// Elementwise multiply operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - - result.data[1] = data[1] * rhs.data[1]; - - result.data[2] = data[2] * rhs.data[2]; - - return result; - } - - /// Scalar multiply operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - - result.data[1] = data[1] * s; - - result.data[2] = data[2] * s; - - return result; - } - - /// Scalar multiply operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - - data[1] *= s; - - data[2] *= s; - - return *this; - } - - /// Elementwise divide operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - - result.data[1] = data[1] / rhs.data[1]; - - result.data[2] = data[2] / rhs.data[2]; - - return result; - } - - /// Scalar divide operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - - result.data[1] = data[1] / s; - - result.data[2] = data[2] / s; - - return result; - } - - /// Scalar divide operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - - data[1] /= s; - - data[2] /= s; - - return *this; - } - - /// Elementwise divide operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (3-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - - data[1] /= rhs.data[1]; - - data[2] /= rhs.data[2]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - - return m; - } - - /// Matrix product of size 3-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[1] * rhs.data[0]; - accum.data[2] += data[2] * rhs.data[0]; - - return accum; - } - - /// Matrix product of size 3-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 3-by-2-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[1] * rhs.data[0]; - accum.data[3] += data[1] * rhs.data[1]; - accum.data[4] += data[2] * rhs.data[0]; - accum.data[5] += data[2] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 3-by-2-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-3-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[1] * rhs.data[0]; - accum.data[4] += data[1] * rhs.data[1]; - accum.data[5] += data[1] * rhs.data[2]; - accum.data[6] += data[2] * rhs.data[0]; - accum.data[7] += data[2] * rhs.data[1]; - accum.data[8] += data[2] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 3-by-3-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-4-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[1] * rhs.data[0]; - accum.data[5] += data[1] * rhs.data[1]; - accum.data[6] += data[1] * rhs.data[2]; - accum.data[7] += data[1] * rhs.data[3]; - accum.data[8] += data[2] * rhs.data[0]; - accum.data[9] += data[2] * rhs.data[1]; - accum.data[10] += data[2] * rhs.data[2]; - accum.data[11] += data[2] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 3-by-4-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Dot product of vectors with extent 3 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - return accum; - } - - /// Dot product of vectors with extent 3 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - return accum; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - - return accum; - } - - /// Cross product - CUTLASS_HOST_DEVICE - Matrix cross(Matrix const &rhs) const { - return Matrix( - data[1] * rhs.data[2] - data[2] * rhs.data[1], - data[2] * rhs.data[0] - data[0] * rhs.data[2], - data[0] * rhs.data[1] - data[1] * rhs.data[0] - ); - } - -}; - -/// Template alias for 3-by-1 matrix -template -using Matrix3x1 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix3x1 make_Matrix3x1( - Element _0_0, - Element _1_0, - Element _2_0 -) { - return Matrix3x1( - _0_0, - _1_0, - _2_0 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 3-by-2 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 3; - - /// Number of columns in matrix - static int const kColumns = 2; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 6; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 3-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 3-by-2 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, - Element _1_0, Element _1_1, - Element _2_0, Element _2_1 - ) { - - data[0] = _0_0; data[1] = _0_1; - data[2] = _1_0; data[3] = _1_1; - data[4] = _2_0; data[5] = _2_1; - } - - /// Constructs a 3-by-2 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1, - Matrix const &row_2 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_1.data[0]; - data[3] = row_1.data[1]; - data[4] = row_2.data[0]; - data[5] = row_2.data[1]; - } - - /// Static method to construct a 3-by-2 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_0.data[1]; - result.data[3] = column_1.data[1]; - result.data[4] = column_0.data[2]; - result.data[5] = column_1.data[2]; - return result; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[4] = diag.data[1]; - m.data[8] = diag.data[2]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[4] = diag.data[1]; - m.data[8] = diag.data[2]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[4]; - diag.data[2] = data[8]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[3] = data[1]; - mt.data[1] = data[2]; - mt.data[4] = data[3]; - mt.data[2] = data[4]; - mt.data[5] = data[5]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x2(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x2(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 2] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - m.data[2] = data[i * 2 + j + 2]; - m.data[3] = data[i * 2 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - data[i * 2 + j + 2] = m.data[2]; - data[i * 2 + j + 3] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 2]; - m.data[2] = data[i * 2 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 2] = m.data[1]; - data[i * 2 + j + 4] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_3x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_3x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - m.data[2] = data[i * 2 + j + 2]; - m.data[3] = data[i * 2 + j + 3]; - m.data[4] = data[i * 2 + j + 4]; - m.data[5] = data[i * 2 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - data[i * 2 + j + 2] = m.data[2]; - data[i * 2 + j + 3] = m.data[3]; - data[i * 2 + j + 4] = m.data[4]; - data[i * 2 + j + 5] = m.data[5]; - - return *this; - } - - /// Forms a 3-by-2 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0) - , lhs.at(1, 0), rhs.at(1, 0) - , lhs.at(2, 0), rhs.at(2, 0)); - } - - /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 3-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 2-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1) - , lower.at(0, 0), lower.at(0, 1) - , lower.at(1, 0), lower.at(1, 1)); - } - - /// Forms a 3-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 1-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1) - , upper.at(1, 0), upper.at(1, 1) - , lower.at(0, 0), lower.at(0, 1)); - } - - /// Concatenates this matrix with a a 1-by-2 matrix to form a 4-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Forms a 3-by-2 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Element B, - Matrix const & C, Matrix const & D) { - return Matrix( - A, B - , C.at(0, 0), D.at(0, 0) - , C.at(1, 0), D.at(1, 0) - ); - } - - /// Forms a 3-by-2 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Element C, Element D) { - return Matrix( - A.at(0, 0), B.at(0, 0) - , A.at(1, 0), B.at(1, 0) - , C, D - ); - } - - /// Elementwise add operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - - return result; - } - - /// Elementwise add operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - - return *this; - } - - /// Elementwise subtract operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - - return result; - } - - /// Elementwise subtract operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - - return *this; - } - - /// Elementwise multiply operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - - return result; - } - - /// Scalar multiply operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - - return result; - } - - /// Scalar multiply operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - - data[2] *= s; - data[3] *= s; - - data[4] *= s; - data[5] *= s; - - return *this; - } - - /// Elementwise divide operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - - return result; - } - - /// Scalar divide operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - - return result; - } - - /// Scalar divide operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - - data[2] /= s; - data[3] /= s; - - data[4] /= s; - data[5] /= s; - - return *this; - } - - /// Elementwise divide operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (3-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - - return m; - } - - /// Matrix product of size 3-by-1-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[2] * rhs.data[0]; - accum.data[2] += data[4] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[3] * rhs.data[1]; - accum.data[2] += data[5] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 3-by-1-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[2] * rhs.data[0]; - accum.data[3] += data[2] * rhs.data[1]; - accum.data[4] += data[4] * rhs.data[0]; - accum.data[5] += data[4] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[3] * rhs.data[2]; - accum.data[3] += data[3] * rhs.data[3]; - accum.data[4] += data[5] * rhs.data[2]; - accum.data[5] += data[5] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 3-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 3-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[2] * rhs.data[0]; - accum.data[4] += data[2] * rhs.data[1]; - accum.data[5] += data[2] * rhs.data[2]; - accum.data[6] += data[4] * rhs.data[0]; - accum.data[7] += data[4] * rhs.data[1]; - accum.data[8] += data[4] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[3] * rhs.data[3]; - accum.data[4] += data[3] * rhs.data[4]; - accum.data[5] += data[3] * rhs.data[5]; - accum.data[6] += data[5] * rhs.data[3]; - accum.data[7] += data[5] * rhs.data[4]; - accum.data[8] += data[5] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 3-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[2] * rhs.data[0]; - accum.data[5] += data[2] * rhs.data[1]; - accum.data[6] += data[2] * rhs.data[2]; - accum.data[7] += data[2] * rhs.data[3]; - accum.data[8] += data[4] * rhs.data[0]; - accum.data[9] += data[4] * rhs.data[1]; - accum.data[10] += data[4] * rhs.data[2]; - accum.data[11] += data[4] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[3] * rhs.data[4]; - accum.data[5] += data[3] * rhs.data[5]; - accum.data[6] += data[3] * rhs.data[6]; - accum.data[7] += data[3] * rhs.data[7]; - accum.data[8] += data[5] * rhs.data[4]; - accum.data[9] += data[5] * rhs.data[5]; - accum.data[10] += data[5] * rhs.data[6]; - accum.data[11] += data[5] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 3-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[3]; - - return accum; - } - -}; - -/// Template alias for 3-by-2 matrix -template -using Matrix3x2 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix3x2 make_Matrix3x2( - Element _0_0, Element _0_1, - Element _1_0, Element _1_1, - Element _2_0, Element _2_1 -) { - return Matrix3x2( - _0_0, _0_1, - _1_0, _1_1, - _2_0, _2_1 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 3-by-3 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 3; - - /// Number of columns in matrix - static int const kColumns = 3; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 9; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 3-by-3 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, - Element _1_0, Element _1_1, Element _1_2, - Element _2_0, Element _2_1, Element _2_2 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; - data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; - data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; - } - - /// Constructs a 3-by-3 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1, - Matrix const &row_2 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_0.data[2]; - data[3] = row_1.data[0]; - data[4] = row_1.data[1]; - data[5] = row_1.data[2]; - data[6] = row_2.data[0]; - data[7] = row_2.data[1]; - data[8] = row_2.data[2]; - } - - /// Static method to construct a 3-by-3 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1, - Matrix const &column_2 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_2.data[0]; - result.data[3] = column_0.data[1]; - result.data[4] = column_1.data[1]; - result.data[5] = column_2.data[1]; - result.data[6] = column_0.data[2]; - result.data[7] = column_1.data[2]; - result.data[8] = column_2.data[2]; - return result; - } - - /// Constructs an identity matrix - CUTLASS_HOST_DEVICE - static Matrix identity() { - Matrix m; - - m.data[0] = Element(1); - m.data[4] = Element(1); - m.data[8] = Element(1); - - return m; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - m.data[6] = s; - m.data[7] = s; - m.data[8] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[4] = diag.data[1]; - m.data[8] = diag.data[2]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[4] = diag.data[1]; - m.data[8] = diag.data[2]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[4]; - diag.data[2] = data[8]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[3] = data[1]; - mt.data[6] = data[2]; - mt.data[1] = data[3]; - mt.data[4] = data[4]; - mt.data[7] = data[5]; - mt.data[2] = data[6]; - mt.data[5] = data[7]; - mt.data[8] = data[8]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x3(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x3(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 3] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 3]; - m.data[3] = data[i * 3 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 3] = m.data[2]; - data[i * 3 + j + 4] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - m.data[3] = data[i * 3 + j + 3]; - m.data[4] = data[i * 3 + j + 4]; - m.data[5] = data[i * 3 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - data[i * 3 + j + 3] = m.data[3]; - data[i * 3 + j + 4] = m.data[4]; - data[i * 3 + j + 5] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 3]; - m.data[2] = data[i * 3 + j + 6]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 3] = m.data[1]; - data[i * 3 + j + 6] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_3x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_3x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 3]; - m.data[3] = data[i * 3 + j + 4]; - m.data[4] = data[i * 3 + j + 6]; - m.data[5] = data[i * 3 + j + 7]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 3] = m.data[2]; - data[i * 3 + j + 4] = m.data[3]; - data[i * 3 + j + 6] = m.data[4]; - data[i * 3 + j + 7] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - m.data[3] = data[i * 3 + j + 3]; - m.data[4] = data[i * 3 + j + 4]; - m.data[5] = data[i * 3 + j + 5]; - m.data[6] = data[i * 3 + j + 6]; - m.data[7] = data[i * 3 + j + 7]; - m.data[8] = data[i * 3 + j + 8]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - data[i * 3 + j + 3] = m.data[3]; - data[i * 3 + j + 4] = m.data[4]; - data[i * 3 + j + 5] = m.data[5]; - data[i * 3 + j + 6] = m.data[6]; - data[i * 3 + j + 7] = m.data[7]; - data[i * 3 + j + 8] = m.data[8]; - - return *this; - } - - /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) - , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) - , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1)); - } - - /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) - , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) - , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0)); - } - - /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 3-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 2-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) - , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); - } - - /// Forms a 3-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 1-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) - , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); - } - - /// Concatenates this matrix with a a 1-by-3 matrix to form a 4-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Forms a 3-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A, B.at(0, 0), B.at(0, 1) - , C.at(0, 0), D.at(0, 0), D.at(0, 1) - , C.at(1, 0), D.at(1, 0), D.at(1, 1) - ); - } - - /// Forms a 3-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Element B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B - , C.at(0, 0), C.at(0, 1), D.at(0, 0) - , C.at(1, 0), C.at(1, 1), D.at(1, 0) - ); - } - - /// Forms a 3-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Element C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0), B.at(0, 1) - , A.at(1, 0), B.at(1, 0), B.at(1, 1) - , C, D.at(0, 0), D.at(0, 1) - ); - } - - /// Forms a 3-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Element D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0) - , A.at(1, 0), A.at(1, 1), B.at(1, 0) - , C.at(0, 0), C.at(0, 1), D - ); - } - - /// Elementwise add operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - - result.data[3] = data[3] + rhs.data[3]; - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - - result.data[6] = data[6] + rhs.data[6]; - result.data[7] = data[7] + rhs.data[7]; - result.data[8] = data[8] + rhs.data[8]; - - return result; - } - - /// Elementwise add operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - - data[3] += rhs.data[3]; - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - - data[6] += rhs.data[6]; - data[7] += rhs.data[7]; - data[8] += rhs.data[8]; - - return *this; - } - - /// Elementwise subtract operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - - result.data[3] = data[3] - rhs.data[3]; - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - - result.data[6] = data[6] - rhs.data[6]; - result.data[7] = data[7] - rhs.data[7]; - result.data[8] = data[8] - rhs.data[8]; - - return result; - } - - /// Elementwise subtract operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - - data[3] -= rhs.data[3]; - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - - data[6] -= rhs.data[6]; - data[7] -= rhs.data[7]; - data[8] -= rhs.data[8]; - - return *this; - } - - /// Elementwise multiply operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - - result.data[3] = data[3] * rhs.data[3]; - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - - result.data[6] = data[6] * rhs.data[6]; - result.data[7] = data[7] * rhs.data[7]; - result.data[8] = data[8] * rhs.data[8]; - - return result; - } - - /// Scalar multiply operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - - result.data[3] = data[3] * s; - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - - result.data[6] = data[6] * s; - result.data[7] = data[7] * s; - result.data[8] = data[8] * s; - - return result; - } - - /// Scalar multiply operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - - data[3] *= s; - data[4] *= s; - data[5] *= s; - - data[6] *= s; - data[7] *= s; - data[8] *= s; - - return *this; - } - - /// Elementwise divide operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - - result.data[3] = data[3] / rhs.data[3]; - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - - result.data[6] = data[6] / rhs.data[6]; - result.data[7] = data[7] / rhs.data[7]; - result.data[8] = data[8] / rhs.data[8]; - - return result; - } - - /// Scalar divide operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - - result.data[3] = data[3] / s; - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - - result.data[6] = data[6] / s; - result.data[7] = data[7] / s; - result.data[8] = data[8] / s; - - return result; - } - - /// Scalar divide operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - - data[3] /= s; - data[4] /= s; - data[5] /= s; - - data[6] /= s; - data[7] /= s; - data[8] /= s; - - return *this; - } - - /// Elementwise divide operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (3-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - - data[3] /= rhs.data[3]; - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - - data[6] /= rhs.data[6]; - data[7] /= rhs.data[7]; - data[8] /= rhs.data[8]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - m.data[6] = -data[6]; - m.data[7] = -data[7]; - m.data[8] = -data[8]; - - return m; - } - - /// Matrix product of size 3-by-1-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[3] * rhs.data[0]; - accum.data[2] += data[6] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[4] * rhs.data[1]; - accum.data[2] += data[7] * rhs.data[1]; - - // k=2 - accum.data[0] += data[2] * rhs.data[2]; - accum.data[1] += data[5] * rhs.data[2]; - accum.data[2] += data[8] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 3-by-1-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[3] * rhs.data[0]; - accum.data[3] += data[3] * rhs.data[1]; - accum.data[4] += data[6] * rhs.data[0]; - accum.data[5] += data[6] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[4] * rhs.data[2]; - accum.data[3] += data[4] * rhs.data[3]; - accum.data[4] += data[7] * rhs.data[2]; - accum.data[5] += data[7] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - accum.data[2] += data[5] * rhs.data[4]; - accum.data[3] += data[5] * rhs.data[5]; - accum.data[4] += data[8] * rhs.data[4]; - accum.data[5] += data[8] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 3-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[3] * rhs.data[0]; - accum.data[4] += data[3] * rhs.data[1]; - accum.data[5] += data[3] * rhs.data[2]; - accum.data[6] += data[6] * rhs.data[0]; - accum.data[7] += data[6] * rhs.data[1]; - accum.data[8] += data[6] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[4] * rhs.data[3]; - accum.data[4] += data[4] * rhs.data[4]; - accum.data[5] += data[4] * rhs.data[5]; - accum.data[6] += data[7] * rhs.data[3]; - accum.data[7] += data[7] * rhs.data[4]; - accum.data[8] += data[7] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - accum.data[3] += data[5] * rhs.data[6]; - accum.data[4] += data[5] * rhs.data[7]; - accum.data[5] += data[5] * rhs.data[8]; - accum.data[6] += data[8] * rhs.data[6]; - accum.data[7] += data[8] * rhs.data[7]; - accum.data[8] += data[8] * rhs.data[8]; - - return accum; - } - - /// Matrix product of size 3-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 3-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[3] * rhs.data[0]; - accum.data[5] += data[3] * rhs.data[1]; - accum.data[6] += data[3] * rhs.data[2]; - accum.data[7] += data[3] * rhs.data[3]; - accum.data[8] += data[6] * rhs.data[0]; - accum.data[9] += data[6] * rhs.data[1]; - accum.data[10] += data[6] * rhs.data[2]; - accum.data[11] += data[6] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[4] * rhs.data[4]; - accum.data[5] += data[4] * rhs.data[5]; - accum.data[6] += data[4] * rhs.data[6]; - accum.data[7] += data[4] * rhs.data[7]; - accum.data[8] += data[7] * rhs.data[4]; - accum.data[9] += data[7] * rhs.data[5]; - accum.data[10] += data[7] * rhs.data[6]; - accum.data[11] += data[7] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - accum.data[4] += data[5] * rhs.data[8]; - accum.data[5] += data[5] * rhs.data[9]; - accum.data[6] += data[5] * rhs.data[10]; - accum.data[7] += data[5] * rhs.data[11]; - accum.data[8] += data[8] * rhs.data[8]; - accum.data[9] += data[8] * rhs.data[9]; - accum.data[10] += data[8] * rhs.data[10]; - accum.data[11] += data[8] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 3-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - accum += data[6]; - accum += data[7]; - accum += data[8]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - accum += data[6] * data[6]; - accum += data[7] * data[7]; - accum += data[8] * data[8]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[4]; - accum += data[8]; - - return accum; - } - - /// Returns 3-by-3 rotation matrix around the X axis - CUTLASS_HOST_DEVICE - static Matrix rotation_X(Element theta) { - Matrix m = identity(); - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - m.at(1, 1) = c; - m.at(1, 2) = -s; - m.at(2, 1) = s; - m.at(2, 2) = c; - - return m; - } - - /// Returns 3-by-3 rotation matrix around the Y axis - CUTLASS_HOST_DEVICE - static Matrix rotation_Y(Element theta) { - Matrix m = identity(); - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - m.at(0, 0) = c; - m.at(2, 0) = -s; - m.at(0, 2) = s; - m.at(2, 2) = c; - - return m; - } - - /// Returns 3-by-3 rotation matrix around the Z axis - CUTLASS_HOST_DEVICE - static Matrix rotation_Z(Element theta) { - Matrix m = Matrix::identity(); - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - m.at(0, 0) = c; - m.at(0, 1) = -s; - m.at(1, 0) = s; - m.at(1, 1) = c; - - return m; - } - - /// Returns a 3-by-3 rotation matrix around a unit-length axis - CUTLASS_HOST_DEVICE - static Matrix rotation(Element theta, Matrix const &u) { - Element x = u.data[0]; - Element y = u.data[1]; - Element z = u.data[2]; - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - Element one_minus_cos = Element(1) - fast_cos(theta); - - Matrix m; - - m.set_slice_3x3({ - c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, - y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, - z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos - }); - - return m; - } - - /// Returns a 3-by-3 reflection about the plane specified by the - /// unit-length normal vector n_unit - CUTLASS_HOST_DEVICE - static Matrix reflection(Matrix const &n_unit) { - - Element a = n_unit.data[0]; - Element b = n_unit.data[1]; - Element c = n_unit.data[2]; - - Matrix m = Matrix::identity(); - - m.set_slice_3x3({ - Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, - Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, - Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c - }); - - return m; - } - - /// Computes the determinant of a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Element determinant(Element accum = Element()) const { - - accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(2, 1), at(2, 2) }).determinant(); - accum -= at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(2, 0), at(2, 2) }).determinant(); - accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(2, 0), at(2, 1) }).determinant(); - - return accum; - } - - /// Computes the inverse of a 3-by-3 matrix given - /// the matrix's determinant - CUTLASS_HOST_DEVICE - Matrix inverse(Element det) const { - return Matrix( - at(1, 1) * at(2, 2) - at(1, 2) * at(2, 1), - at(0, 2) * at(2, 1) - at(0, 1) * at(2, 2), - at(0, 1) * at(1, 2) - at(0, 2) * at(1, 1), - - at(1, 2) * at(2, 0) - at(1, 0) * at(2, 2), - at(0, 0) * at(2, 2) - at(0, 2) * at(2, 0), - at(0, 2) * at(1, 0) - at(0, 0) * at(1, 2), - - at(1, 0) * at(2, 1) - at(1, 1) * at(2, 0), - at(0, 1) * at(2, 0) - at(0, 0) * at(2, 1), - at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0) - ) * (Element(1) / det); - } - /// Computes the inverse of a 3-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix inverse() const { - return inverse(determinant()); - } - -}; - -/// Template alias for 3-by-3 matrix -template -using Matrix3x3 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix3x3 make_Matrix3x3( - Element _0_0, Element _0_1, Element _0_2, - Element _1_0, Element _1_1, Element _1_2, - Element _2_0, Element _2_1, Element _2_2 -) { - return Matrix3x3( - _0_0, _0_1, _0_2, - _1_0, _1_1, _1_2, - _2_0, _2_1, _2_2 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 3-by-4 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 3; - - /// Number of columns in matrix - static int const kColumns = 4; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 12; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 3-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 3-by-4 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, Element _0_3, - Element _1_0, Element _1_1, Element _1_2, Element _1_3, - Element _2_0, Element _2_1, Element _2_2, Element _2_3 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; - data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; - data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; - } - - /// Constructs a 3-by-4 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1, - Matrix const &row_2 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_0.data[2]; - data[3] = row_0.data[3]; - data[4] = row_1.data[0]; - data[5] = row_1.data[1]; - data[6] = row_1.data[2]; - data[7] = row_1.data[3]; - data[8] = row_2.data[0]; - data[9] = row_2.data[1]; - data[10] = row_2.data[2]; - data[11] = row_2.data[3]; - } - - /// Static method to construct a 3-by-4 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1, - Matrix const &column_2, - Matrix const &column_3 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_2.data[0]; - result.data[3] = column_3.data[0]; - result.data[4] = column_0.data[1]; - result.data[5] = column_1.data[1]; - result.data[6] = column_2.data[1]; - result.data[7] = column_3.data[1]; - result.data[8] = column_0.data[2]; - result.data[9] = column_1.data[2]; - result.data[10] = column_2.data[2]; - result.data[11] = column_3.data[2]; - return result; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - m.data[6] = s; - m.data[7] = s; - m.data[8] = s; - m.data[9] = s; - m.data[10] = s; - m.data[11] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[4] = diag.data[1]; - m.data[8] = diag.data[2]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[4] = diag.data[1]; - m.data[8] = diag.data[2]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[4]; - diag.data[2] = data[8]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[3] = data[1]; - mt.data[6] = data[2]; - mt.data[9] = data[3]; - mt.data[1] = data[4]; - mt.data[4] = data[5]; - mt.data[7] = data[6]; - mt.data[10] = data[7]; - mt.data[2] = data[8]; - mt.data[5] = data[9]; - mt.data[8] = data[10]; - mt.data[11] = data[11]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 3 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x4(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x4(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 4] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 4]; - m.data[3] = data[i * 4 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 4] = m.data[2]; - data[i * 4 + j + 5] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 4]; - m.data[4] = data[i * 4 + j + 5]; - m.data[5] = data[i * 4 + j + 6]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 4] = m.data[3]; - data[i * 4 + j + 5] = m.data[4]; - data[i * 4 + j + 6] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - m.data[4] = data[i * 4 + j + 4]; - m.data[5] = data[i * 4 + j + 5]; - m.data[6] = data[i * 4 + j + 6]; - m.data[7] = data[i * 4 + j + 7]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - data[i * 4 + j + 4] = m.data[4]; - data[i * 4 + j + 5] = m.data[5]; - data[i * 4 + j + 6] = m.data[6]; - data[i * 4 + j + 7] = m.data[7]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 4]; - m.data[2] = data[i * 4 + j + 8]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 4] = m.data[1]; - data[i * 4 + j + 8] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_3x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_3x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 4]; - m.data[3] = data[i * 4 + j + 5]; - m.data[4] = data[i * 4 + j + 8]; - m.data[5] = data[i * 4 + j + 9]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 4] = m.data[2]; - data[i * 4 + j + 5] = m.data[3]; - data[i * 4 + j + 8] = m.data[4]; - data[i * 4 + j + 9] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 4]; - m.data[4] = data[i * 4 + j + 5]; - m.data[5] = data[i * 4 + j + 6]; - m.data[6] = data[i * 4 + j + 8]; - m.data[7] = data[i * 4 + j + 9]; - m.data[8] = data[i * 4 + j + 10]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 4] = m.data[3]; - data[i * 4 + j + 5] = m.data[4]; - data[i * 4 + j + 6] = m.data[5]; - data[i * 4 + j + 8] = m.data[6]; - data[i * 4 + j + 9] = m.data[7]; - data[i * 4 + j + 10] = m.data[8]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - m.data[4] = data[i * 4 + j + 4]; - m.data[5] = data[i * 4 + j + 5]; - m.data[6] = data[i * 4 + j + 6]; - m.data[7] = data[i * 4 + j + 7]; - m.data[8] = data[i * 4 + j + 8]; - m.data[9] = data[i * 4 + j + 9]; - m.data[10] = data[i * 4 + j + 10]; - m.data[11] = data[i * 4 + j + 11]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - data[i * 4 + j + 4] = m.data[4]; - data[i * 4 + j + 5] = m.data[5]; - data[i * 4 + j + 6] = m.data[6]; - data[i * 4 + j + 7] = m.data[7]; - data[i * 4 + j + 8] = m.data[8]; - data[i * 4 + j + 9] = m.data[9]; - data[i * 4 + j + 10] = m.data[10]; - data[i * 4 + j + 11] = m.data[11]; - - return *this; - } - - /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) - , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) - , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2)); - } - - /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) - , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) - , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1)); - } - - /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-3 matrix with a 3-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) - , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) - , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0)); - } - - /// Forms a 3-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 2-by-4 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) - , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); - } - - /// Forms a 3-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 1-by-4 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) - , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); - } - - /// Concatenates this matrix with a a 1-by-4 matrix to form a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix vcat(Matrix const & rhs) const { - return Matrix::vcat(*this, rhs); - } - - /// Forms a 3-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A, B.at(0, 0), B.at(0, 1), B.at(0, 2) - , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) - , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) - ); - } - - /// Forms a 3-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) - , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) - , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) - ); - } - - /// Forms a 3-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Element B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), A.at(0, 2), B - , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) - , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) - ); - } - - /// Forms a 3-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Element C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) - , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) - , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) - ); - } - - /// Forms a 3-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) - , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) - , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) - ); - } - - /// Forms a 3-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Element D) { - return Matrix( - A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) - , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) - , C.at(0, 0), C.at(0, 1), C.at(0, 2), D - ); - } - - /// Elementwise add operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - result.data[6] = data[6] + rhs.data[6]; - result.data[7] = data[7] + rhs.data[7]; - - result.data[8] = data[8] + rhs.data[8]; - result.data[9] = data[9] + rhs.data[9]; - result.data[10] = data[10] + rhs.data[10]; - result.data[11] = data[11] + rhs.data[11]; - - return result; - } - - /// Elementwise add operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - data[6] += rhs.data[6]; - data[7] += rhs.data[7]; - - data[8] += rhs.data[8]; - data[9] += rhs.data[9]; - data[10] += rhs.data[10]; - data[11] += rhs.data[11]; - - return *this; - } - - /// Elementwise subtract operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - result.data[6] = data[6] - rhs.data[6]; - result.data[7] = data[7] - rhs.data[7]; - - result.data[8] = data[8] - rhs.data[8]; - result.data[9] = data[9] - rhs.data[9]; - result.data[10] = data[10] - rhs.data[10]; - result.data[11] = data[11] - rhs.data[11]; - - return result; - } - - /// Elementwise subtract operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - data[6] -= rhs.data[6]; - data[7] -= rhs.data[7]; - - data[8] -= rhs.data[8]; - data[9] -= rhs.data[9]; - data[10] -= rhs.data[10]; - data[11] -= rhs.data[11]; - - return *this; - } - - /// Elementwise multiply operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - result.data[6] = data[6] * rhs.data[6]; - result.data[7] = data[7] * rhs.data[7]; - - result.data[8] = data[8] * rhs.data[8]; - result.data[9] = data[9] * rhs.data[9]; - result.data[10] = data[10] * rhs.data[10]; - result.data[11] = data[11] * rhs.data[11]; - - return result; - } - - /// Scalar multiply operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - result.data[6] = data[6] * s; - result.data[7] = data[7] * s; - - result.data[8] = data[8] * s; - result.data[9] = data[9] * s; - result.data[10] = data[10] * s; - result.data[11] = data[11] * s; - - return result; - } - - /// Scalar multiply operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - data[3] *= s; - - data[4] *= s; - data[5] *= s; - data[6] *= s; - data[7] *= s; - - data[8] *= s; - data[9] *= s; - data[10] *= s; - data[11] *= s; - - return *this; - } - - /// Elementwise divide operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - result.data[6] = data[6] / rhs.data[6]; - result.data[7] = data[7] / rhs.data[7]; - - result.data[8] = data[8] / rhs.data[8]; - result.data[9] = data[9] / rhs.data[9]; - result.data[10] = data[10] / rhs.data[10]; - result.data[11] = data[11] / rhs.data[11]; - - return result; - } - - /// Scalar divide operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - result.data[6] = data[6] / s; - result.data[7] = data[7] / s; - - result.data[8] = data[8] / s; - result.data[9] = data[9] / s; - result.data[10] = data[10] / s; - result.data[11] = data[11] / s; - - return result; - } - - /// Scalar divide operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - data[3] /= s; - - data[4] /= s; - data[5] /= s; - data[6] /= s; - data[7] /= s; - - data[8] /= s; - data[9] /= s; - data[10] /= s; - data[11] /= s; - - return *this; - } - - /// Elementwise divide operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (3-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - data[6] /= rhs.data[6]; - data[7] /= rhs.data[7]; - - data[8] /= rhs.data[8]; - data[9] /= rhs.data[9]; - data[10] /= rhs.data[10]; - data[11] /= rhs.data[11]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - m.data[6] = -data[6]; - m.data[7] = -data[7]; - m.data[8] = -data[8]; - m.data[9] = -data[9]; - m.data[10] = -data[10]; - m.data[11] = -data[11]; - - return m; - } - - /// Matrix product of size 3-by-1-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[4] * rhs.data[0]; - accum.data[2] += data[8] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[5] * rhs.data[1]; - accum.data[2] += data[9] * rhs.data[1]; - - // k=2 - accum.data[0] += data[2] * rhs.data[2]; - accum.data[1] += data[6] * rhs.data[2]; - accum.data[2] += data[10] * rhs.data[2]; - - // k=3 - accum.data[0] += data[3] * rhs.data[3]; - accum.data[1] += data[7] * rhs.data[3]; - accum.data[2] += data[11] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 3-by-1-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[4] * rhs.data[0]; - accum.data[3] += data[4] * rhs.data[1]; - accum.data[4] += data[8] * rhs.data[0]; - accum.data[5] += data[8] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[5] * rhs.data[2]; - accum.data[3] += data[5] * rhs.data[3]; - accum.data[4] += data[9] * rhs.data[2]; - accum.data[5] += data[9] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - accum.data[2] += data[6] * rhs.data[4]; - accum.data[3] += data[6] * rhs.data[5]; - accum.data[4] += data[10] * rhs.data[4]; - accum.data[5] += data[10] * rhs.data[5]; - - // k=3 - accum.data[0] += data[3] * rhs.data[6]; - accum.data[1] += data[3] * rhs.data[7]; - accum.data[2] += data[7] * rhs.data[6]; - accum.data[3] += data[7] * rhs.data[7]; - accum.data[4] += data[11] * rhs.data[6]; - accum.data[5] += data[11] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 3-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[4] * rhs.data[0]; - accum.data[4] += data[4] * rhs.data[1]; - accum.data[5] += data[4] * rhs.data[2]; - accum.data[6] += data[8] * rhs.data[0]; - accum.data[7] += data[8] * rhs.data[1]; - accum.data[8] += data[8] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[5] * rhs.data[3]; - accum.data[4] += data[5] * rhs.data[4]; - accum.data[5] += data[5] * rhs.data[5]; - accum.data[6] += data[9] * rhs.data[3]; - accum.data[7] += data[9] * rhs.data[4]; - accum.data[8] += data[9] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - accum.data[3] += data[6] * rhs.data[6]; - accum.data[4] += data[6] * rhs.data[7]; - accum.data[5] += data[6] * rhs.data[8]; - accum.data[6] += data[10] * rhs.data[6]; - accum.data[7] += data[10] * rhs.data[7]; - accum.data[8] += data[10] * rhs.data[8]; - - // k=3 - accum.data[0] += data[3] * rhs.data[9]; - accum.data[1] += data[3] * rhs.data[10]; - accum.data[2] += data[3] * rhs.data[11]; - accum.data[3] += data[7] * rhs.data[9]; - accum.data[4] += data[7] * rhs.data[10]; - accum.data[5] += data[7] * rhs.data[11]; - accum.data[6] += data[11] * rhs.data[9]; - accum.data[7] += data[11] * rhs.data[10]; - accum.data[8] += data[11] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 3-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[4] * rhs.data[0]; - accum.data[5] += data[4] * rhs.data[1]; - accum.data[6] += data[4] * rhs.data[2]; - accum.data[7] += data[4] * rhs.data[3]; - accum.data[8] += data[8] * rhs.data[0]; - accum.data[9] += data[8] * rhs.data[1]; - accum.data[10] += data[8] * rhs.data[2]; - accum.data[11] += data[8] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[5] * rhs.data[4]; - accum.data[5] += data[5] * rhs.data[5]; - accum.data[6] += data[5] * rhs.data[6]; - accum.data[7] += data[5] * rhs.data[7]; - accum.data[8] += data[9] * rhs.data[4]; - accum.data[9] += data[9] * rhs.data[5]; - accum.data[10] += data[9] * rhs.data[6]; - accum.data[11] += data[9] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - accum.data[4] += data[6] * rhs.data[8]; - accum.data[5] += data[6] * rhs.data[9]; - accum.data[6] += data[6] * rhs.data[10]; - accum.data[7] += data[6] * rhs.data[11]; - accum.data[8] += data[10] * rhs.data[8]; - accum.data[9] += data[10] * rhs.data[9]; - accum.data[10] += data[10] * rhs.data[10]; - accum.data[11] += data[10] * rhs.data[11]; - - // k=3 - accum.data[0] += data[3] * rhs.data[12]; - accum.data[1] += data[3] * rhs.data[13]; - accum.data[2] += data[3] * rhs.data[14]; - accum.data[3] += data[3] * rhs.data[15]; - accum.data[4] += data[7] * rhs.data[12]; - accum.data[5] += data[7] * rhs.data[13]; - accum.data[6] += data[7] * rhs.data[14]; - accum.data[7] += data[7] * rhs.data[15]; - accum.data[8] += data[11] * rhs.data[12]; - accum.data[9] += data[11] * rhs.data[13]; - accum.data[10] += data[11] * rhs.data[14]; - accum.data[11] += data[11] * rhs.data[15]; - - return accum; - } - - /// Matrix product of size 3-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 3-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - accum += data[6]; - accum += data[7]; - accum += data[8]; - accum += data[9]; - accum += data[10]; - accum += data[11]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - accum += data[6] * data[6]; - accum += data[7] * data[7]; - accum += data[8] * data[8]; - accum += data[9] * data[9]; - accum += data[10] * data[10]; - accum += data[11] * data[11]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[5]; - accum += data[10]; - - return accum; - } - -}; - -/// Template alias for 3-by-4 matrix -template -using Matrix3x4 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix3x4 make_Matrix3x4( - Element _0_0, Element _0_1, Element _0_2, Element _0_3, - Element _1_0, Element _1_1, Element _1_2, Element _1_3, - Element _2_0, Element _2_1, Element _2_2, Element _2_3 -) { - return Matrix3x4( - _0_0, _0_1, _0_2, _0_3, - _1_0, _1_1, _1_2, _1_3, - _2_0, _2_1, _2_2, _2_3 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 4-by-1 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 4; - - /// Number of columns in matrix - static int const kColumns = 1; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 4; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 4-by-1 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 4-by-1 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, - Element _1_0, - Element _2_0, - Element _3_0 - ) { - - data[0] = _0_0; - data[1] = _1_0; - data[2] = _2_0; - data[3] = _3_0; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[1] = data[1]; - mt.data[2] = data[2]; - mt.data[3] = data[3]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 1 + j + 0]; - m.data[1] = data[i * 1 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 1 + j + 0] = m.data[0]; - data[i * 1 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 1 + j + 0]; - m.data[1] = data[i * 1 + j + 1]; - m.data[2] = data[i * 1 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 1 + j + 0] = m.data[0]; - data[i * 1 + j + 1] = m.data[1]; - data[i * 1 + j + 2] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 1 + j + 0]; - m.data[1] = data[i * 1 + j + 1]; - m.data[2] = data[i * 1 + j + 2]; - m.data[3] = data[i * 1 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 1 + j + 0] = m.data[0]; - data[i * 1 + j + 1] = m.data[1]; - data[i * 1 + j + 2] = m.data[2]; - data[i * 1 + j + 3] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_4x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_4x1(v, 0, j); - } - - /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 4-by-3 matrix to form a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 4-by-1 matrix by vertically concatenating an Element with a 3-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Element upper, Matrix const & lower) { - return Matrix( - upper - , lower.at(0, 0) - , lower.at(1, 0) - , lower.at(2, 0)); - } - - /// Forms a 4-by-1 matrix by vertically concatenating a 2-by-1 matrix with a 2-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0) - , upper.at(1, 0) - , lower.at(0, 0) - , lower.at(1, 0)); - } - - /// Forms a 4-by-1 matrix by vertically concatenating a 3-by-1 matrix with an Element - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Element lower) { - return Matrix( - upper.at(0, 0) - , upper.at(1, 0) - , upper.at(2, 0) - , lower); - } - - /// Elementwise add operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - - result.data[1] = data[1] + rhs.data[1]; - - result.data[2] = data[2] + rhs.data[2]; - - result.data[3] = data[3] + rhs.data[3]; - - return result; - } - - /// Elementwise add operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - - data[1] += rhs.data[1]; - - data[2] += rhs.data[2]; - - data[3] += rhs.data[3]; - - return *this; - } - - /// Elementwise subtract operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - - result.data[1] = data[1] - rhs.data[1]; - - result.data[2] = data[2] - rhs.data[2]; - - result.data[3] = data[3] - rhs.data[3]; - - return result; - } - - /// Elementwise subtract operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - - data[1] -= rhs.data[1]; - - data[2] -= rhs.data[2]; - - data[3] -= rhs.data[3]; - - return *this; - } - - /// Elementwise multiply operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - - result.data[1] = data[1] * rhs.data[1]; - - result.data[2] = data[2] * rhs.data[2]; - - result.data[3] = data[3] * rhs.data[3]; - - return result; - } - - /// Scalar multiply operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - - result.data[1] = data[1] * s; - - result.data[2] = data[2] * s; - - result.data[3] = data[3] * s; - - return result; - } - - /// Scalar multiply operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - - data[1] *= s; - - data[2] *= s; - - data[3] *= s; - - return *this; - } - - /// Elementwise divide operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - - result.data[1] = data[1] / rhs.data[1]; - - result.data[2] = data[2] / rhs.data[2]; - - result.data[3] = data[3] / rhs.data[3]; - - return result; - } - - /// Scalar divide operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - - result.data[1] = data[1] / s; - - result.data[2] = data[2] / s; - - result.data[3] = data[3] / s; - - return result; - } - - /// Scalar divide operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - - data[1] /= s; - - data[2] /= s; - - data[3] /= s; - - return *this; - } - - /// Elementwise divide operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (4-by-1) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - - data[1] /= rhs.data[1]; - - data[2] /= rhs.data[2]; - - data[3] /= rhs.data[3]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - - return m; - } - - /// Matrix product of size 4-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[1] * rhs.data[0]; - accum.data[2] += data[2] * rhs.data[0]; - accum.data[3] += data[3] * rhs.data[0]; - - return accum; - } - - /// Matrix product of size 4-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-1-by-1 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 4-by-2-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[1] * rhs.data[0]; - accum.data[3] += data[1] * rhs.data[1]; - accum.data[4] += data[2] * rhs.data[0]; - accum.data[5] += data[2] * rhs.data[1]; - accum.data[6] += data[3] * rhs.data[0]; - accum.data[7] += data[3] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 4-by-2-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-3-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[1] * rhs.data[0]; - accum.data[4] += data[1] * rhs.data[1]; - accum.data[5] += data[1] * rhs.data[2]; - accum.data[6] += data[2] * rhs.data[0]; - accum.data[7] += data[2] * rhs.data[1]; - accum.data[8] += data[2] * rhs.data[2]; - accum.data[9] += data[3] * rhs.data[0]; - accum.data[10] += data[3] * rhs.data[1]; - accum.data[11] += data[3] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 4-by-3-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-4-by-1 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[1] * rhs.data[0]; - accum.data[5] += data[1] * rhs.data[1]; - accum.data[6] += data[1] * rhs.data[2]; - accum.data[7] += data[1] * rhs.data[3]; - accum.data[8] += data[2] * rhs.data[0]; - accum.data[9] += data[2] * rhs.data[1]; - accum.data[10] += data[2] * rhs.data[2]; - accum.data[11] += data[2] * rhs.data[3]; - accum.data[12] += data[3] * rhs.data[0]; - accum.data[13] += data[3] * rhs.data[1]; - accum.data[14] += data[3] * rhs.data[2]; - accum.data[15] += data[3] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 4-by-4-by-1 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Dot product of vectors with extent 4 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - accum += data[3] * rhs.data[3]; - return accum; - } - - /// Dot product of vectors with extent 4 - CUTLASS_HOST_DEVICE - Element dot(Matrix const &rhs, Element accum = Element()) const { - - accum += data[0] * rhs.data[0]; - accum += data[1] * rhs.data[1]; - accum += data[2] * rhs.data[2]; - accum += data[3] * rhs.data[3]; - return accum; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - - return accum; - } - -}; - -/// Template alias for 4-by-1 matrix -template -using Matrix4x1 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix4x1 make_Matrix4x1( - Element _0_0, - Element _1_0, - Element _2_0, - Element _3_0 -) { - return Matrix4x1( - _0_0, - _1_0, - _2_0, - _3_0 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 4-by-2 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 4; - - /// Number of columns in matrix - static int const kColumns = 2; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 8; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 4-by-2 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 4-by-2 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, - Element _1_0, Element _1_1, - Element _2_0, Element _2_1, - Element _3_0, Element _3_1 - ) { - - data[0] = _0_0; data[1] = _0_1; - data[2] = _1_0; data[3] = _1_1; - data[4] = _2_0; data[5] = _2_1; - data[6] = _3_0; data[7] = _3_1; - } - - /// Constructs a 4-by-2 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1, - Matrix const &row_2, - Matrix const &row_3 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_1.data[0]; - data[3] = row_1.data[1]; - data[4] = row_2.data[0]; - data[5] = row_2.data[1]; - data[6] = row_3.data[0]; - data[7] = row_3.data[1]; - } - - /// Static method to construct a 4-by-2 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_0.data[1]; - result.data[3] = column_1.data[1]; - result.data[4] = column_0.data[2]; - result.data[5] = column_1.data[2]; - result.data[6] = column_0.data[3]; - result.data[7] = column_1.data[3]; - return result; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - m.data[6] = s; - m.data[7] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[5] = diag.data[1]; - m.data[10] = diag.data[2]; - m.data[15] = diag.data[3]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[5] = diag.data[1]; - m.data[10] = diag.data[2]; - m.data[15] = diag.data[3]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[5]; - diag.data[2] = data[10]; - diag.data[3] = data[15]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[4] = data[1]; - mt.data[1] = data[2]; - mt.data[5] = data[3]; - mt.data[2] = data[4]; - mt.data[6] = data[5]; - mt.data[3] = data[6]; - mt.data[7] = data[7]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x2(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x2(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 2] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - m.data[2] = data[i * 2 + j + 2]; - m.data[3] = data[i * 2 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - data[i * 2 + j + 2] = m.data[2]; - data[i * 2 + j + 3] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 2]; - m.data[2] = data[i * 2 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 2] = m.data[1]; - data[i * 2 + j + 4] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - m.data[2] = data[i * 2 + j + 2]; - m.data[3] = data[i * 2 + j + 3]; - m.data[4] = data[i * 2 + j + 4]; - m.data[5] = data[i * 2 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - data[i * 2 + j + 2] = m.data[2]; - data[i * 2 + j + 3] = m.data[3]; - data[i * 2 + j + 4] = m.data[4]; - data[i * 2 + j + 5] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 2]; - m.data[2] = data[i * 2 + j + 4]; - m.data[3] = data[i * 2 + j + 6]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 2] = m.data[1]; - data[i * 2 + j + 4] = m.data[2]; - data[i * 2 + j + 6] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_4x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_4x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 2 + j + 0]; - m.data[1] = data[i * 2 + j + 1]; - m.data[2] = data[i * 2 + j + 2]; - m.data[3] = data[i * 2 + j + 3]; - m.data[4] = data[i * 2 + j + 4]; - m.data[5] = data[i * 2 + j + 5]; - m.data[6] = data[i * 2 + j + 6]; - m.data[7] = data[i * 2 + j + 7]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 2 + j + 0] = m.data[0]; - data[i * 2 + j + 1] = m.data[1]; - data[i * 2 + j + 2] = m.data[2]; - data[i * 2 + j + 3] = m.data[3]; - data[i * 2 + j + 4] = m.data[4]; - data[i * 2 + j + 5] = m.data[5]; - data[i * 2 + j + 6] = m.data[6]; - data[i * 2 + j + 7] = m.data[7]; - - return *this; - } - - /// Forms a 4-by-2 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0) - , lhs.at(1, 0), rhs.at(1, 0) - , lhs.at(2, 0), rhs.at(2, 0) - , lhs.at(3, 0), rhs.at(3, 0)); - } - - /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 4-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 3-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1) - , lower.at(0, 0), lower.at(0, 1) - , lower.at(1, 0), lower.at(1, 1) - , lower.at(2, 0), lower.at(2, 1)); - } - - /// Forms a 4-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 2-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1) - , upper.at(1, 0), upper.at(1, 1) - , lower.at(0, 0), lower.at(0, 1) - , lower.at(1, 0), lower.at(1, 1)); - } - - /// Forms a 4-by-2 matrix by vertically concatenating a 3-by-2 matrix with a 1-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1) - , upper.at(1, 0), upper.at(1, 1) - , upper.at(2, 0), upper.at(2, 1) - , lower.at(0, 0), lower.at(0, 1)); - } - - /// Forms a 4-by-2 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Element B, - Matrix const & C, Matrix const & D) { - return Matrix( - A, B - , C.at(0, 0), D.at(0, 0) - , C.at(1, 0), D.at(1, 0) - , C.at(2, 0), D.at(2, 0) - ); - } - - /// Forms a 4-by-2 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0) - , A.at(1, 0), B.at(1, 0) - , C.at(0, 0), D.at(0, 0) - , C.at(1, 0), D.at(1, 0) - ); - } - - /// Forms a 4-by-2 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Element C, Element D) { - return Matrix( - A.at(0, 0), B.at(0, 0) - , A.at(1, 0), B.at(1, 0) - , A.at(2, 0), B.at(2, 0) - , C, D - ); - } - - /// Elementwise add operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - - result.data[6] = data[6] + rhs.data[6]; - result.data[7] = data[7] + rhs.data[7]; - - return result; - } - - /// Elementwise add operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - - data[6] += rhs.data[6]; - data[7] += rhs.data[7]; - - return *this; - } - - /// Elementwise subtract operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - - result.data[6] = data[6] - rhs.data[6]; - result.data[7] = data[7] - rhs.data[7]; - - return result; - } - - /// Elementwise subtract operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - - data[6] -= rhs.data[6]; - data[7] -= rhs.data[7]; - - return *this; - } - - /// Elementwise multiply operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - - result.data[6] = data[6] * rhs.data[6]; - result.data[7] = data[7] * rhs.data[7]; - - return result; - } - - /// Scalar multiply operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - - result.data[6] = data[6] * s; - result.data[7] = data[7] * s; - - return result; - } - - /// Scalar multiply operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - - data[2] *= s; - data[3] *= s; - - data[4] *= s; - data[5] *= s; - - data[6] *= s; - data[7] *= s; - - return *this; - } - - /// Elementwise divide operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - - result.data[6] = data[6] / rhs.data[6]; - result.data[7] = data[7] / rhs.data[7]; - - return result; - } - - /// Scalar divide operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - - result.data[6] = data[6] / s; - result.data[7] = data[7] / s; - - return result; - } - - /// Scalar divide operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - - data[2] /= s; - data[3] /= s; - - data[4] /= s; - data[5] /= s; - - data[6] /= s; - data[7] /= s; - - return *this; - } - - /// Elementwise divide operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (4-by-2) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - - data[6] /= rhs.data[6]; - data[7] /= rhs.data[7]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - m.data[6] = -data[6]; - m.data[7] = -data[7]; - - return m; - } - - /// Matrix product of size 4-by-1-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[2] * rhs.data[0]; - accum.data[2] += data[4] * rhs.data[0]; - accum.data[3] += data[6] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[3] * rhs.data[1]; - accum.data[2] += data[5] * rhs.data[1]; - accum.data[3] += data[7] * rhs.data[1]; - - return accum; - } - - /// Matrix product of size 4-by-1-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[2] * rhs.data[0]; - accum.data[3] += data[2] * rhs.data[1]; - accum.data[4] += data[4] * rhs.data[0]; - accum.data[5] += data[4] * rhs.data[1]; - accum.data[6] += data[6] * rhs.data[0]; - accum.data[7] += data[6] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[3] * rhs.data[2]; - accum.data[3] += data[3] * rhs.data[3]; - accum.data[4] += data[5] * rhs.data[2]; - accum.data[5] += data[5] * rhs.data[3]; - accum.data[6] += data[7] * rhs.data[2]; - accum.data[7] += data[7] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 4-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-2-by-2 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 4-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[2] * rhs.data[0]; - accum.data[4] += data[2] * rhs.data[1]; - accum.data[5] += data[2] * rhs.data[2]; - accum.data[6] += data[4] * rhs.data[0]; - accum.data[7] += data[4] * rhs.data[1]; - accum.data[8] += data[4] * rhs.data[2]; - accum.data[9] += data[6] * rhs.data[0]; - accum.data[10] += data[6] * rhs.data[1]; - accum.data[11] += data[6] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[3] * rhs.data[3]; - accum.data[4] += data[3] * rhs.data[4]; - accum.data[5] += data[3] * rhs.data[5]; - accum.data[6] += data[5] * rhs.data[3]; - accum.data[7] += data[5] * rhs.data[4]; - accum.data[8] += data[5] * rhs.data[5]; - accum.data[9] += data[7] * rhs.data[3]; - accum.data[10] += data[7] * rhs.data[4]; - accum.data[11] += data[7] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 4-by-3-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[2] * rhs.data[0]; - accum.data[5] += data[2] * rhs.data[1]; - accum.data[6] += data[2] * rhs.data[2]; - accum.data[7] += data[2] * rhs.data[3]; - accum.data[8] += data[4] * rhs.data[0]; - accum.data[9] += data[4] * rhs.data[1]; - accum.data[10] += data[4] * rhs.data[2]; - accum.data[11] += data[4] * rhs.data[3]; - accum.data[12] += data[6] * rhs.data[0]; - accum.data[13] += data[6] * rhs.data[1]; - accum.data[14] += data[6] * rhs.data[2]; - accum.data[15] += data[6] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[3] * rhs.data[4]; - accum.data[5] += data[3] * rhs.data[5]; - accum.data[6] += data[3] * rhs.data[6]; - accum.data[7] += data[3] * rhs.data[7]; - accum.data[8] += data[5] * rhs.data[4]; - accum.data[9] += data[5] * rhs.data[5]; - accum.data[10] += data[5] * rhs.data[6]; - accum.data[11] += data[5] * rhs.data[7]; - accum.data[12] += data[7] * rhs.data[4]; - accum.data[13] += data[7] * rhs.data[5]; - accum.data[14] += data[7] * rhs.data[6]; - accum.data[15] += data[7] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 4-by-4-by-2 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - accum += data[6]; - accum += data[7]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - accum += data[6] * data[6]; - accum += data[7] * data[7]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[3]; - - return accum; - } - -}; - -/// Template alias for 4-by-2 matrix -template -using Matrix4x2 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix4x2 make_Matrix4x2( - Element _0_0, Element _0_1, - Element _1_0, Element _1_1, - Element _2_0, Element _2_1, - Element _3_0, Element _3_1 -) { - return Matrix4x2( - _0_0, _0_1, - _1_0, _1_1, - _2_0, _2_1, - _3_0, _3_1 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 4-by-3 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 4; - - /// Number of columns in matrix - static int const kColumns = 3; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 12; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 4-by-3 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 4-by-3 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, - Element _1_0, Element _1_1, Element _1_2, - Element _2_0, Element _2_1, Element _2_2, - Element _3_0, Element _3_1, Element _3_2 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; - data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; - data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; - data[9] = _3_0; data[10] = _3_1; data[11] = _3_2; - } - - /// Constructs a 4-by-3 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1, - Matrix const &row_2, - Matrix const &row_3 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_0.data[2]; - data[3] = row_1.data[0]; - data[4] = row_1.data[1]; - data[5] = row_1.data[2]; - data[6] = row_2.data[0]; - data[7] = row_2.data[1]; - data[8] = row_2.data[2]; - data[9] = row_3.data[0]; - data[10] = row_3.data[1]; - data[11] = row_3.data[2]; - } - - /// Static method to construct a 4-by-3 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1, - Matrix const &column_2 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_2.data[0]; - result.data[3] = column_0.data[1]; - result.data[4] = column_1.data[1]; - result.data[5] = column_2.data[1]; - result.data[6] = column_0.data[2]; - result.data[7] = column_1.data[2]; - result.data[8] = column_2.data[2]; - result.data[9] = column_0.data[3]; - result.data[10] = column_1.data[3]; - result.data[11] = column_2.data[3]; - return result; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - m.data[6] = s; - m.data[7] = s; - m.data[8] = s; - m.data[9] = s; - m.data[10] = s; - m.data[11] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[5] = diag.data[1]; - m.data[10] = diag.data[2]; - m.data[15] = diag.data[3]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[5] = diag.data[1]; - m.data[10] = diag.data[2]; - m.data[15] = diag.data[3]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[5]; - diag.data[2] = data[10]; - diag.data[3] = data[15]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[4] = data[1]; - mt.data[8] = data[2]; - mt.data[1] = data[3]; - mt.data[5] = data[4]; - mt.data[9] = data[5]; - mt.data[2] = data[6]; - mt.data[6] = data[7]; - mt.data[10] = data[8]; - mt.data[3] = data[9]; - mt.data[7] = data[10]; - mt.data[11] = data[11]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x3(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x3(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 3] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 3]; - m.data[3] = data[i * 3 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 3] = m.data[2]; - data[i * 3 + j + 4] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - m.data[3] = data[i * 3 + j + 3]; - m.data[4] = data[i * 3 + j + 4]; - m.data[5] = data[i * 3 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - data[i * 3 + j + 3] = m.data[3]; - data[i * 3 + j + 4] = m.data[4]; - data[i * 3 + j + 5] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 3]; - m.data[2] = data[i * 3 + j + 6]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 3] = m.data[1]; - data[i * 3 + j + 6] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 3]; - m.data[3] = data[i * 3 + j + 4]; - m.data[4] = data[i * 3 + j + 6]; - m.data[5] = data[i * 3 + j + 7]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 3] = m.data[2]; - data[i * 3 + j + 4] = m.data[3]; - data[i * 3 + j + 6] = m.data[4]; - data[i * 3 + j + 7] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - m.data[3] = data[i * 3 + j + 3]; - m.data[4] = data[i * 3 + j + 4]; - m.data[5] = data[i * 3 + j + 5]; - m.data[6] = data[i * 3 + j + 6]; - m.data[7] = data[i * 3 + j + 7]; - m.data[8] = data[i * 3 + j + 8]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - data[i * 3 + j + 3] = m.data[3]; - data[i * 3 + j + 4] = m.data[4]; - data[i * 3 + j + 5] = m.data[5]; - data[i * 3 + j + 6] = m.data[6]; - data[i * 3 + j + 7] = m.data[7]; - data[i * 3 + j + 8] = m.data[8]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 3]; - m.data[2] = data[i * 3 + j + 6]; - m.data[3] = data[i * 3 + j + 9]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 3] = m.data[1]; - data[i * 3 + j + 6] = m.data[2]; - data[i * 3 + j + 9] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_4x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_4x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 3]; - m.data[3] = data[i * 3 + j + 4]; - m.data[4] = data[i * 3 + j + 6]; - m.data[5] = data[i * 3 + j + 7]; - m.data[6] = data[i * 3 + j + 9]; - m.data[7] = data[i * 3 + j + 10]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 3] = m.data[2]; - data[i * 3 + j + 4] = m.data[3]; - data[i * 3 + j + 6] = m.data[4]; - data[i * 3 + j + 7] = m.data[5]; - data[i * 3 + j + 9] = m.data[6]; - data[i * 3 + j + 10] = m.data[7]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 3 + j + 0]; - m.data[1] = data[i * 3 + j + 1]; - m.data[2] = data[i * 3 + j + 2]; - m.data[3] = data[i * 3 + j + 3]; - m.data[4] = data[i * 3 + j + 4]; - m.data[5] = data[i * 3 + j + 5]; - m.data[6] = data[i * 3 + j + 6]; - m.data[7] = data[i * 3 + j + 7]; - m.data[8] = data[i * 3 + j + 8]; - m.data[9] = data[i * 3 + j + 9]; - m.data[10] = data[i * 3 + j + 10]; - m.data[11] = data[i * 3 + j + 11]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 3 + j + 0] = m.data[0]; - data[i * 3 + j + 1] = m.data[1]; - data[i * 3 + j + 2] = m.data[2]; - data[i * 3 + j + 3] = m.data[3]; - data[i * 3 + j + 4] = m.data[4]; - data[i * 3 + j + 5] = m.data[5]; - data[i * 3 + j + 6] = m.data[6]; - data[i * 3 + j + 7] = m.data[7]; - data[i * 3 + j + 8] = m.data[8]; - data[i * 3 + j + 9] = m.data[9]; - data[i * 3 + j + 10] = m.data[10]; - data[i * 3 + j + 11] = m.data[11]; - - return *this; - } - - /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) - , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) - , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1) - , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1)); - } - - /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) - , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) - , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0) - , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0)); - } - - /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix hcat(Matrix const & rhs) const { - return Matrix::hcat(*this, rhs); - } - - /// Forms a 4-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 3-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) - , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2) - , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2)); - } - - /// Forms a 4-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 2-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) - , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) - , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); - } - - /// Forms a 4-by-3 matrix by vertically concatenating a 3-by-3 matrix with a 1-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) - , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) - , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); - } - - /// Forms a 4-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A, B.at(0, 0), B.at(0, 1) - , C.at(0, 0), D.at(0, 0), D.at(0, 1) - , C.at(1, 0), D.at(1, 0), D.at(1, 1) - , C.at(2, 0), D.at(2, 0), D.at(2, 1) - ); - } - - /// Forms a 4-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Element B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B - , C.at(0, 0), C.at(0, 1), D.at(0, 0) - , C.at(1, 0), C.at(1, 1), D.at(1, 0) - , C.at(2, 0), C.at(2, 1), D.at(2, 0) - ); - } - - /// Forms a 4-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0), B.at(0, 1) - , A.at(1, 0), B.at(1, 0), B.at(1, 1) - , C.at(0, 0), D.at(0, 0), D.at(0, 1) - , C.at(1, 0), D.at(1, 0), D.at(1, 1) - ); - } - - /// Forms a 4-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0) - , A.at(1, 0), A.at(1, 1), B.at(1, 0) - , C.at(0, 0), C.at(0, 1), D.at(0, 0) - , C.at(1, 0), C.at(1, 1), D.at(1, 0) - ); - } - - /// Forms a 4-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Element C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0), B.at(0, 1) - , A.at(1, 0), B.at(1, 0), B.at(1, 1) - , A.at(2, 0), B.at(2, 0), B.at(2, 1) - , C, D.at(0, 0), D.at(0, 1) - ); - } - - /// Forms a 4-by-3 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Element D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0) - , A.at(1, 0), A.at(1, 1), B.at(1, 0) - , A.at(2, 0), A.at(2, 1), B.at(2, 0) - , C.at(0, 0), C.at(0, 1), D - ); - } - - /// Elementwise add operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - - result.data[3] = data[3] + rhs.data[3]; - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - - result.data[6] = data[6] + rhs.data[6]; - result.data[7] = data[7] + rhs.data[7]; - result.data[8] = data[8] + rhs.data[8]; - - result.data[9] = data[9] + rhs.data[9]; - result.data[10] = data[10] + rhs.data[10]; - result.data[11] = data[11] + rhs.data[11]; - - return result; - } - - /// Elementwise add operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - - data[3] += rhs.data[3]; - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - - data[6] += rhs.data[6]; - data[7] += rhs.data[7]; - data[8] += rhs.data[8]; - - data[9] += rhs.data[9]; - data[10] += rhs.data[10]; - data[11] += rhs.data[11]; - - return *this; - } - - /// Elementwise subtract operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - - result.data[3] = data[3] - rhs.data[3]; - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - - result.data[6] = data[6] - rhs.data[6]; - result.data[7] = data[7] - rhs.data[7]; - result.data[8] = data[8] - rhs.data[8]; - - result.data[9] = data[9] - rhs.data[9]; - result.data[10] = data[10] - rhs.data[10]; - result.data[11] = data[11] - rhs.data[11]; - - return result; - } - - /// Elementwise subtract operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - - data[3] -= rhs.data[3]; - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - - data[6] -= rhs.data[6]; - data[7] -= rhs.data[7]; - data[8] -= rhs.data[8]; - - data[9] -= rhs.data[9]; - data[10] -= rhs.data[10]; - data[11] -= rhs.data[11]; - - return *this; - } - - /// Elementwise multiply operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - - result.data[3] = data[3] * rhs.data[3]; - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - - result.data[6] = data[6] * rhs.data[6]; - result.data[7] = data[7] * rhs.data[7]; - result.data[8] = data[8] * rhs.data[8]; - - result.data[9] = data[9] * rhs.data[9]; - result.data[10] = data[10] * rhs.data[10]; - result.data[11] = data[11] * rhs.data[11]; - - return result; - } - - /// Scalar multiply operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - - result.data[3] = data[3] * s; - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - - result.data[6] = data[6] * s; - result.data[7] = data[7] * s; - result.data[8] = data[8] * s; - - result.data[9] = data[9] * s; - result.data[10] = data[10] * s; - result.data[11] = data[11] * s; - - return result; - } - - /// Scalar multiply operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - - data[3] *= s; - data[4] *= s; - data[5] *= s; - - data[6] *= s; - data[7] *= s; - data[8] *= s; - - data[9] *= s; - data[10] *= s; - data[11] *= s; - - return *this; - } - - /// Elementwise divide operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - - result.data[3] = data[3] / rhs.data[3]; - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - - result.data[6] = data[6] / rhs.data[6]; - result.data[7] = data[7] / rhs.data[7]; - result.data[8] = data[8] / rhs.data[8]; - - result.data[9] = data[9] / rhs.data[9]; - result.data[10] = data[10] / rhs.data[10]; - result.data[11] = data[11] / rhs.data[11]; - - return result; - } - - /// Scalar divide operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - - result.data[3] = data[3] / s; - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - - result.data[6] = data[6] / s; - result.data[7] = data[7] / s; - result.data[8] = data[8] / s; - - result.data[9] = data[9] / s; - result.data[10] = data[10] / s; - result.data[11] = data[11] / s; - - return result; - } - - /// Scalar divide operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - - data[3] /= s; - data[4] /= s; - data[5] /= s; - - data[6] /= s; - data[7] /= s; - data[8] /= s; - - data[9] /= s; - data[10] /= s; - data[11] /= s; - - return *this; - } - - /// Elementwise divide operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (4-by-3) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - - data[3] /= rhs.data[3]; - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - - data[6] /= rhs.data[6]; - data[7] /= rhs.data[7]; - data[8] /= rhs.data[8]; - - data[9] /= rhs.data[9]; - data[10] /= rhs.data[10]; - data[11] /= rhs.data[11]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - m.data[6] = -data[6]; - m.data[7] = -data[7]; - m.data[8] = -data[8]; - m.data[9] = -data[9]; - m.data[10] = -data[10]; - m.data[11] = -data[11]; - - return m; - } - - /// Matrix product of size 4-by-1-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[3] * rhs.data[0]; - accum.data[2] += data[6] * rhs.data[0]; - accum.data[3] += data[9] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[4] * rhs.data[1]; - accum.data[2] += data[7] * rhs.data[1]; - accum.data[3] += data[10] * rhs.data[1]; - - // k=2 - accum.data[0] += data[2] * rhs.data[2]; - accum.data[1] += data[5] * rhs.data[2]; - accum.data[2] += data[8] * rhs.data[2]; - accum.data[3] += data[11] * rhs.data[2]; - - return accum; - } - - /// Matrix product of size 4-by-1-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[3] * rhs.data[0]; - accum.data[3] += data[3] * rhs.data[1]; - accum.data[4] += data[6] * rhs.data[0]; - accum.data[5] += data[6] * rhs.data[1]; - accum.data[6] += data[9] * rhs.data[0]; - accum.data[7] += data[9] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[4] * rhs.data[2]; - accum.data[3] += data[4] * rhs.data[3]; - accum.data[4] += data[7] * rhs.data[2]; - accum.data[5] += data[7] * rhs.data[3]; - accum.data[6] += data[10] * rhs.data[2]; - accum.data[7] += data[10] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - accum.data[2] += data[5] * rhs.data[4]; - accum.data[3] += data[5] * rhs.data[5]; - accum.data[4] += data[8] * rhs.data[4]; - accum.data[5] += data[8] * rhs.data[5]; - accum.data[6] += data[11] * rhs.data[4]; - accum.data[7] += data[11] * rhs.data[5]; - - return accum; - } - - /// Matrix product of size 4-by-2-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[3] * rhs.data[0]; - accum.data[4] += data[3] * rhs.data[1]; - accum.data[5] += data[3] * rhs.data[2]; - accum.data[6] += data[6] * rhs.data[0]; - accum.data[7] += data[6] * rhs.data[1]; - accum.data[8] += data[6] * rhs.data[2]; - accum.data[9] += data[9] * rhs.data[0]; - accum.data[10] += data[9] * rhs.data[1]; - accum.data[11] += data[9] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[4] * rhs.data[3]; - accum.data[4] += data[4] * rhs.data[4]; - accum.data[5] += data[4] * rhs.data[5]; - accum.data[6] += data[7] * rhs.data[3]; - accum.data[7] += data[7] * rhs.data[4]; - accum.data[8] += data[7] * rhs.data[5]; - accum.data[9] += data[10] * rhs.data[3]; - accum.data[10] += data[10] * rhs.data[4]; - accum.data[11] += data[10] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - accum.data[3] += data[5] * rhs.data[6]; - accum.data[4] += data[5] * rhs.data[7]; - accum.data[5] += data[5] * rhs.data[8]; - accum.data[6] += data[8] * rhs.data[6]; - accum.data[7] += data[8] * rhs.data[7]; - accum.data[8] += data[8] * rhs.data[8]; - accum.data[9] += data[11] * rhs.data[6]; - accum.data[10] += data[11] * rhs.data[7]; - accum.data[11] += data[11] * rhs.data[8]; - - return accum; - } - - /// Matrix product of size 4-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-3-by-3 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Matrix product of size 4-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[3] * rhs.data[0]; - accum.data[5] += data[3] * rhs.data[1]; - accum.data[6] += data[3] * rhs.data[2]; - accum.data[7] += data[3] * rhs.data[3]; - accum.data[8] += data[6] * rhs.data[0]; - accum.data[9] += data[6] * rhs.data[1]; - accum.data[10] += data[6] * rhs.data[2]; - accum.data[11] += data[6] * rhs.data[3]; - accum.data[12] += data[9] * rhs.data[0]; - accum.data[13] += data[9] * rhs.data[1]; - accum.data[14] += data[9] * rhs.data[2]; - accum.data[15] += data[9] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[4] * rhs.data[4]; - accum.data[5] += data[4] * rhs.data[5]; - accum.data[6] += data[4] * rhs.data[6]; - accum.data[7] += data[4] * rhs.data[7]; - accum.data[8] += data[7] * rhs.data[4]; - accum.data[9] += data[7] * rhs.data[5]; - accum.data[10] += data[7] * rhs.data[6]; - accum.data[11] += data[7] * rhs.data[7]; - accum.data[12] += data[10] * rhs.data[4]; - accum.data[13] += data[10] * rhs.data[5]; - accum.data[14] += data[10] * rhs.data[6]; - accum.data[15] += data[10] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - accum.data[4] += data[5] * rhs.data[8]; - accum.data[5] += data[5] * rhs.data[9]; - accum.data[6] += data[5] * rhs.data[10]; - accum.data[7] += data[5] * rhs.data[11]; - accum.data[8] += data[8] * rhs.data[8]; - accum.data[9] += data[8] * rhs.data[9]; - accum.data[10] += data[8] * rhs.data[10]; - accum.data[11] += data[8] * rhs.data[11]; - accum.data[12] += data[11] * rhs.data[8]; - accum.data[13] += data[11] * rhs.data[9]; - accum.data[14] += data[11] * rhs.data[10]; - accum.data[15] += data[11] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 4-by-4-by-3 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - accum += data[6]; - accum += data[7]; - accum += data[8]; - accum += data[9]; - accum += data[10]; - accum += data[11]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - accum += data[6] * data[6]; - accum += data[7] * data[7]; - accum += data[8] * data[8]; - accum += data[9] * data[9]; - accum += data[10] * data[10]; - accum += data[11] * data[11]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[4]; - accum += data[8]; - - return accum; - } - -}; - -/// Template alias for 4-by-3 matrix -template -using Matrix4x3 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix4x3 make_Matrix4x3( - Element _0_0, Element _0_1, Element _0_2, - Element _1_0, Element _1_1, Element _1_2, - Element _2_0, Element _2_1, Element _2_2, - Element _3_0, Element _3_1, Element _3_2 -) { - return Matrix4x3( - _0_0, _0_1, _0_2, - _1_0, _1_1, _1_2, - _2_0, _2_1, _2_2, - _3_0, _3_1, _3_2 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// 4-by-4 matrix template class definition -template -struct Matrix { - - // - // Type definitions - // - - /// Element data type - using Element = Element_; - - /// Number of rows in matrix - static int const kRows = 4; - - /// Number of columns in matrix - static int const kColumns = 4; - - /// Layout of matrix in underlying array - using Layout = layout::RowMajor; - - /// Number of elements in matrix - static int const kCount = 16; - - // - // Data members - // - - /// Elements of the matrix in row-major layout - Array data; - - // - // Methods - // - - /// Constructs a zero matrix - CUTLASS_HOST_DEVICE - Matrix() { - data.clear(); - } - - /// Copy constructor for a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Matrix(Matrix const &rhs) { - data = rhs.data; - } - - /// Constructs a 4-by-4 matrix from scalar elements - CUTLASS_HOST_DEVICE - Matrix( - Element _0_0, Element _0_1, Element _0_2, Element _0_3, - Element _1_0, Element _1_1, Element _1_2, Element _1_3, - Element _2_0, Element _2_1, Element _2_2, Element _2_3, - Element _3_0, Element _3_1, Element _3_2, Element _3_3 - ) { - - data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; - data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; - data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; - data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3; - } - - /// Constructs a 4-by-4 matrix from row vectors - CUTLASS_HOST_DEVICE - Matrix( - Matrix const &row_0, - Matrix const &row_1, - Matrix const &row_2, - Matrix const &row_3 - ) { - data[0] = row_0.data[0]; - data[1] = row_0.data[1]; - data[2] = row_0.data[2]; - data[3] = row_0.data[3]; - data[4] = row_1.data[0]; - data[5] = row_1.data[1]; - data[6] = row_1.data[2]; - data[7] = row_1.data[3]; - data[8] = row_2.data[0]; - data[9] = row_2.data[1]; - data[10] = row_2.data[2]; - data[11] = row_2.data[3]; - data[12] = row_3.data[0]; - data[13] = row_3.data[1]; - data[14] = row_3.data[2]; - data[15] = row_3.data[3]; - } - - /// Static method to construct a 4-by-4 matrix from column vectors - CUTLASS_HOST_DEVICE - static Matrix from_columns( - Matrix const &column_0, - Matrix const &column_1, - Matrix const &column_2, - Matrix const &column_3 - ) { - Matrix result; - - result.data[0] = column_0.data[0]; - result.data[1] = column_1.data[0]; - result.data[2] = column_2.data[0]; - result.data[3] = column_3.data[0]; - result.data[4] = column_0.data[1]; - result.data[5] = column_1.data[1]; - result.data[6] = column_2.data[1]; - result.data[7] = column_3.data[1]; - result.data[8] = column_0.data[2]; - result.data[9] = column_1.data[2]; - result.data[10] = column_2.data[2]; - result.data[11] = column_3.data[2]; - result.data[12] = column_0.data[3]; - result.data[13] = column_1.data[3]; - result.data[14] = column_2.data[3]; - result.data[15] = column_3.data[3]; - return result; - } - - /// Constructs an identity matrix - CUTLASS_HOST_DEVICE - static Matrix identity() { - Matrix m; - - m.data[0] = Element(1); - m.data[5] = Element(1); - m.data[10] = Element(1); - m.data[15] = Element(1); - - return m; - } - - /// Constructs a matrix from a uniform element - CUTLASS_HOST_DEVICE - static Matrix uniform(Element s) { - Matrix m; - - m.data[0] = s; - m.data[1] = s; - m.data[2] = s; - m.data[3] = s; - m.data[4] = s; - m.data[5] = s; - m.data[6] = s; - m.data[7] = s; - m.data[8] = s; - m.data[9] = s; - m.data[10] = s; - m.data[11] = s; - m.data[12] = s; - m.data[13] = s; - m.data[14] = s; - m.data[15] = s; - - return m; - } - - /// Constructs a matrix from a uniform element 1 - CUTLASS_HOST_DEVICE - static Matrix ones() { - return uniform(Element(1)); - } - - /// Constructs a matrix from a uniform element 0 - CUTLASS_HOST_DEVICE - static Matrix zero() { - return Matrix(); - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[5] = diag.data[1]; - m.data[10] = diag.data[2]; - m.data[15] = diag.data[3]; - - return m; - } - - /// Constructs a matrix from elements along its diagonal - CUTLASS_HOST_DEVICE - static Matrix from_diagonal(Matrix const &diag) { - Matrix m; - - m.data[0] = diag.data[0]; - m.data[5] = diag.data[1]; - m.data[10] = diag.data[2]; - m.data[15] = diag.data[3]; - - return m; - } - - /// Gets an array of diagonal elements - CUTLASS_HOST_DEVICE - Matrix diagonal() const { - Matrix diag; - - diag.data[0] = data[0]; - diag.data[1] = data[5]; - diag.data[2] = data[10]; - diag.data[3] = data[15]; - - return diag; - } - - /// Returns a transposed matrix - CUTLASS_HOST_DEVICE - Matrix transpose() const { - Matrix mt; - - mt.data[0] = data[0]; - mt.data[4] = data[1]; - mt.data[8] = data[2]; - mt.data[12] = data[3]; - mt.data[1] = data[4]; - mt.data[5] = data[5]; - mt.data[9] = data[6]; - mt.data[13] = data[7]; - mt.data[2] = data[8]; - mt.data[6] = data[9]; - mt.data[10] = data[10]; - mt.data[14] = data[11]; - mt.data[3] = data[12]; - mt.data[7] = data[13]; - mt.data[11] = data[14]; - mt.data[15] = data[15]; - - return mt; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(int i, int j) const { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(int i, int j) { - return data[i * 4 + j]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element at(Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & at(Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element &at(int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element at(int offset) const { - return data[offset]; - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element operator[](Coord<2> const &coord) const { - return at(coord[0], coord[1]); - } - - /// Accesses an element by coordinate - CUTLASS_HOST_DEVICE - Element & operator[](Coord<2> const &coord) { - return at(coord[0], coord[1]); - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element & operator[](int offset) { - return data[offset]; - } - - /// Accesses an element by offset - CUTLASS_HOST_DEVICE - Element operator[](int offset) const { - return data[offset]; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_1x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix row(int i) const { - return slice_1x4(i, 0); - } - - CUTLASS_HOST_DEVICE - Matrix &set_row(Matrix const &v, int i = 0) { - return set_slice_1x4(v, i, 0); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 4]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 4] = m.data[1]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 4]; - m.data[3] = data[i * 4 + j + 5]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 4] = m.data[2]; - data[i * 4 + j + 5] = m.data[3]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 4]; - m.data[4] = data[i * 4 + j + 5]; - m.data[5] = data[i * 4 + j + 6]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 4] = m.data[3]; - data[i * 4 + j + 5] = m.data[4]; - data[i * 4 + j + 6] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_2x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - m.data[4] = data[i * 4 + j + 4]; - m.data[5] = data[i * 4 + j + 5]; - m.data[6] = data[i * 4 + j + 6]; - m.data[7] = data[i * 4 + j + 7]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - data[i * 4 + j + 4] = m.data[4]; - data[i * 4 + j + 5] = m.data[5]; - data[i * 4 + j + 6] = m.data[6]; - data[i * 4 + j + 7] = m.data[7]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 4]; - m.data[2] = data[i * 4 + j + 8]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 4] = m.data[1]; - data[i * 4 + j + 8] = m.data[2]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 4]; - m.data[3] = data[i * 4 + j + 5]; - m.data[4] = data[i * 4 + j + 8]; - m.data[5] = data[i * 4 + j + 9]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 4] = m.data[2]; - data[i * 4 + j + 5] = m.data[3]; - data[i * 4 + j + 8] = m.data[4]; - data[i * 4 + j + 9] = m.data[5]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 4]; - m.data[4] = data[i * 4 + j + 5]; - m.data[5] = data[i * 4 + j + 6]; - m.data[6] = data[i * 4 + j + 8]; - m.data[7] = data[i * 4 + j + 9]; - m.data[8] = data[i * 4 + j + 10]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 4] = m.data[3]; - data[i * 4 + j + 5] = m.data[4]; - data[i * 4 + j + 6] = m.data[5]; - data[i * 4 + j + 8] = m.data[6]; - data[i * 4 + j + 9] = m.data[7]; - data[i * 4 + j + 10] = m.data[8]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_3x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - m.data[4] = data[i * 4 + j + 4]; - m.data[5] = data[i * 4 + j + 5]; - m.data[6] = data[i * 4 + j + 6]; - m.data[7] = data[i * 4 + j + 7]; - m.data[8] = data[i * 4 + j + 8]; - m.data[9] = data[i * 4 + j + 9]; - m.data[10] = data[i * 4 + j + 10]; - m.data[11] = data[i * 4 + j + 11]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - data[i * 4 + j + 4] = m.data[4]; - data[i * 4 + j + 5] = m.data[5]; - data[i * 4 + j + 6] = m.data[6]; - data[i * 4 + j + 7] = m.data[7]; - data[i * 4 + j + 8] = m.data[8]; - data[i * 4 + j + 9] = m.data[9]; - data[i * 4 + j + 10] = m.data[10]; - data[i * 4 + j + 11] = m.data[11]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x1(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 4]; - m.data[2] = data[i * 4 + j + 8]; - m.data[3] = data[i * 4 + j + 12]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 4] = m.data[1]; - data[i * 4 + j + 8] = m.data[2]; - data[i * 4 + j + 12] = m.data[3]; - - return *this; - } - - CUTLASS_HOST_DEVICE - Matrix column(int j) const { - return slice_4x1(0, j); - } - - CUTLASS_HOST_DEVICE - Matrix &set_column(Matrix const &v, int j =0) { - return set_slice_4x1(v, 0, j); - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x2(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 4]; - m.data[3] = data[i * 4 + j + 5]; - m.data[4] = data[i * 4 + j + 8]; - m.data[5] = data[i * 4 + j + 9]; - m.data[6] = data[i * 4 + j + 12]; - m.data[7] = data[i * 4 + j + 13]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 4] = m.data[2]; - data[i * 4 + j + 5] = m.data[3]; - data[i * 4 + j + 8] = m.data[4]; - data[i * 4 + j + 9] = m.data[5]; - data[i * 4 + j + 12] = m.data[6]; - data[i * 4 + j + 13] = m.data[7]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x3(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 4]; - m.data[4] = data[i * 4 + j + 5]; - m.data[5] = data[i * 4 + j + 6]; - m.data[6] = data[i * 4 + j + 8]; - m.data[7] = data[i * 4 + j + 9]; - m.data[8] = data[i * 4 + j + 10]; - m.data[9] = data[i * 4 + j + 12]; - m.data[10] = data[i * 4 + j + 13]; - m.data[11] = data[i * 4 + j + 14]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 4] = m.data[3]; - data[i * 4 + j + 5] = m.data[4]; - data[i * 4 + j + 6] = m.data[5]; - data[i * 4 + j + 8] = m.data[6]; - data[i * 4 + j + 9] = m.data[7]; - data[i * 4 + j + 10] = m.data[8]; - data[i * 4 + j + 12] = m.data[9]; - data[i * 4 + j + 13] = m.data[10]; - data[i * 4 + j + 14] = m.data[11]; - - return *this; - } - - /// Gets a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix slice_4x4(int i = 0, int j = 0) const { - Matrix m; - - m.data[0] = data[i * 4 + j + 0]; - m.data[1] = data[i * 4 + j + 1]; - m.data[2] = data[i * 4 + j + 2]; - m.data[3] = data[i * 4 + j + 3]; - m.data[4] = data[i * 4 + j + 4]; - m.data[5] = data[i * 4 + j + 5]; - m.data[6] = data[i * 4 + j + 6]; - m.data[7] = data[i * 4 + j + 7]; - m.data[8] = data[i * 4 + j + 8]; - m.data[9] = data[i * 4 + j + 9]; - m.data[10] = data[i * 4 + j + 10]; - m.data[11] = data[i * 4 + j + 11]; - m.data[12] = data[i * 4 + j + 12]; - m.data[13] = data[i * 4 + j + 13]; - m.data[14] = data[i * 4 + j + 14]; - m.data[15] = data[i * 4 + j + 15]; - - return m; - } - - /// Overwrites a submatrix with optional offset - CUTLASS_HOST_DEVICE - Matrix & set_slice_4x4(Matrix const &m, int i = 0, int j = 0) { - - data[i * 4 + j + 0] = m.data[0]; - data[i * 4 + j + 1] = m.data[1]; - data[i * 4 + j + 2] = m.data[2]; - data[i * 4 + j + 3] = m.data[3]; - data[i * 4 + j + 4] = m.data[4]; - data[i * 4 + j + 5] = m.data[5]; - data[i * 4 + j + 6] = m.data[6]; - data[i * 4 + j + 7] = m.data[7]; - data[i * 4 + j + 8] = m.data[8]; - data[i * 4 + j + 9] = m.data[9]; - data[i * 4 + j + 10] = m.data[10]; - data[i * 4 + j + 11] = m.data[11]; - data[i * 4 + j + 12] = m.data[12]; - data[i * 4 + j + 13] = m.data[13]; - data[i * 4 + j + 14] = m.data[14]; - data[i * 4 + j + 15] = m.data[15]; - - return *this; - } - - /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-3 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) - , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) - , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2) - , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1), rhs.at(3, 2)); - } - - /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-2 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) - , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) - , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1) - , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0), rhs.at(3, 1)); - } - - /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-3 matrix with a 4-by-1 matrix - CUTLASS_HOST_DEVICE - static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { - return Matrix( - lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) - , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) - , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0) - , lhs.at(3, 0), lhs.at(3, 1), lhs.at(3, 2), rhs.at(3, 0)); - } - - /// Forms a 4-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 3-by-4 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) - , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3) - , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2), lower.at(2, 3)); - } - - /// Forms a 4-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 2-by-4 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) - , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) - , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); - } - - /// Forms a 4-by-4 matrix by vertically concatenating a 3-by-4 matrix with a 1-by-4 matrix - CUTLASS_HOST_DEVICE - static Matrix vcat(Matrix const & upper, Matrix const & lower) { - return Matrix( - upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) - , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) - , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2), upper.at(2, 3) - , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Element A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A, B.at(0, 0), B.at(0, 1), B.at(0, 2) - , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) - , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) - , C.at(2, 0), D.at(2, 0), D.at(2, 1), D.at(2, 2) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) - , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) - , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) - , C.at(2, 0), C.at(2, 1), D.at(2, 0), D.at(2, 1) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Element B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), A.at(0, 2), B - , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) - , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) - , C.at(2, 0), C.at(2, 1), C.at(2, 2), D.at(2, 0) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) - , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) - , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) - , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) - , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) - , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) - , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) - , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) - , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) - , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Element C, Matrix const & D) { - return Matrix( - A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) - , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) - , A.at(2, 0), B.at(2, 0), B.at(2, 1), B.at(2, 2) - , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Matrix const & D) { - return Matrix( - A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) - , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) - , A.at(2, 0), A.at(2, 1), B.at(2, 0), B.at(2, 1) - , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) - ); - } - - /// Forms a 4-by-4 matrix by concatenating four components - CUTLASS_HOST_DEVICE - static Matrix block( - Matrix const & A, Matrix const & B, - Matrix const & C, Element D) { - return Matrix( - A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) - , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) - , A.at(2, 0), A.at(2, 1), A.at(2, 2), B.at(2, 0) - , C.at(0, 0), C.at(0, 1), C.at(0, 2), D - ); - } - - /// Elementwise add operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix add(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] + rhs.data[0]; - result.data[1] = data[1] + rhs.data[1]; - result.data[2] = data[2] + rhs.data[2]; - result.data[3] = data[3] + rhs.data[3]; - - result.data[4] = data[4] + rhs.data[4]; - result.data[5] = data[5] + rhs.data[5]; - result.data[6] = data[6] + rhs.data[6]; - result.data[7] = data[7] + rhs.data[7]; - - result.data[8] = data[8] + rhs.data[8]; - result.data[9] = data[9] + rhs.data[9]; - result.data[10] = data[10] + rhs.data[10]; - result.data[11] = data[11] + rhs.data[11]; - - result.data[12] = data[12] + rhs.data[12]; - result.data[13] = data[13] + rhs.data[13]; - result.data[14] = data[14] + rhs.data[14]; - result.data[15] = data[15] + rhs.data[15]; - - return result; - } - - /// Elementwise add operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix operator +(Matrix const &rhs) const { - return add(rhs); - } - - /// Elementwise add operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator +=(Matrix const &rhs) { - - data[0] += rhs.data[0]; - data[1] += rhs.data[1]; - data[2] += rhs.data[2]; - data[3] += rhs.data[3]; - - data[4] += rhs.data[4]; - data[5] += rhs.data[5]; - data[6] += rhs.data[6]; - data[7] += rhs.data[7]; - - data[8] += rhs.data[8]; - data[9] += rhs.data[9]; - data[10] += rhs.data[10]; - data[11] += rhs.data[11]; - - data[12] += rhs.data[12]; - data[13] += rhs.data[13]; - data[14] += rhs.data[14]; - data[15] += rhs.data[15]; - - return *this; - } - - /// Elementwise subtract operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix subtract(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] - rhs.data[0]; - result.data[1] = data[1] - rhs.data[1]; - result.data[2] = data[2] - rhs.data[2]; - result.data[3] = data[3] - rhs.data[3]; - - result.data[4] = data[4] - rhs.data[4]; - result.data[5] = data[5] - rhs.data[5]; - result.data[6] = data[6] - rhs.data[6]; - result.data[7] = data[7] - rhs.data[7]; - - result.data[8] = data[8] - rhs.data[8]; - result.data[9] = data[9] - rhs.data[9]; - result.data[10] = data[10] - rhs.data[10]; - result.data[11] = data[11] - rhs.data[11]; - - result.data[12] = data[12] - rhs.data[12]; - result.data[13] = data[13] - rhs.data[13]; - result.data[14] = data[14] - rhs.data[14]; - result.data[15] = data[15] - rhs.data[15]; - - return result; - } - - /// Elementwise subtract operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix operator -(Matrix const &rhs) const { - return subtract(rhs); - } - - /// Elementwise subtract operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator -=(Matrix const &rhs) { - - data[0] -= rhs.data[0]; - data[1] -= rhs.data[1]; - data[2] -= rhs.data[2]; - data[3] -= rhs.data[3]; - - data[4] -= rhs.data[4]; - data[5] -= rhs.data[5]; - data[6] -= rhs.data[6]; - data[7] -= rhs.data[7]; - - data[8] -= rhs.data[8]; - data[9] -= rhs.data[9]; - data[10] -= rhs.data[10]; - data[11] -= rhs.data[11]; - - data[12] -= rhs.data[12]; - data[13] -= rhs.data[13]; - data[14] -= rhs.data[14]; - data[15] -= rhs.data[15]; - - return *this; - } - - /// Elementwise multiply operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] * rhs.data[0]; - result.data[1] = data[1] * rhs.data[1]; - result.data[2] = data[2] * rhs.data[2]; - result.data[3] = data[3] * rhs.data[3]; - - result.data[4] = data[4] * rhs.data[4]; - result.data[5] = data[5] * rhs.data[5]; - result.data[6] = data[6] * rhs.data[6]; - result.data[7] = data[7] * rhs.data[7]; - - result.data[8] = data[8] * rhs.data[8]; - result.data[9] = data[9] * rhs.data[9]; - result.data[10] = data[10] * rhs.data[10]; - result.data[11] = data[11] * rhs.data[11]; - - result.data[12] = data[12] * rhs.data[12]; - result.data[13] = data[13] * rhs.data[13]; - result.data[14] = data[14] * rhs.data[14]; - result.data[15] = data[15] * rhs.data[15]; - - return result; - } - - /// Scalar multiply operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix multiply(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] * s; - result.data[1] = data[1] * s; - result.data[2] = data[2] * s; - result.data[3] = data[3] * s; - - result.data[4] = data[4] * s; - result.data[5] = data[5] * s; - result.data[6] = data[6] * s; - result.data[7] = data[7] * s; - - result.data[8] = data[8] * s; - result.data[9] = data[9] * s; - result.data[10] = data[10] * s; - result.data[11] = data[11] * s; - - result.data[12] = data[12] * s; - result.data[13] = data[13] * s; - result.data[14] = data[14] * s; - result.data[15] = data[15] * s; - - return result; - } - - /// Scalar multiply operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix operator *(Element const &s) const { - return multiply(s); - } - - /// Scalar multiply operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator *=(Element const &s) { - - data[0] *= s; - data[1] *= s; - data[2] *= s; - data[3] *= s; - - data[4] *= s; - data[5] *= s; - data[6] *= s; - data[7] *= s; - - data[8] *= s; - data[9] *= s; - data[10] *= s; - data[11] *= s; - - data[12] *= s; - data[13] *= s; - data[14] *= s; - data[15] *= s; - - return *this; - } - - /// Elementwise divide operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Matrix const &rhs) const { - - Matrix result; - - result.data[0] = data[0] / rhs.data[0]; - result.data[1] = data[1] / rhs.data[1]; - result.data[2] = data[2] / rhs.data[2]; - result.data[3] = data[3] / rhs.data[3]; - - result.data[4] = data[4] / rhs.data[4]; - result.data[5] = data[5] / rhs.data[5]; - result.data[6] = data[6] / rhs.data[6]; - result.data[7] = data[7] / rhs.data[7]; - - result.data[8] = data[8] / rhs.data[8]; - result.data[9] = data[9] / rhs.data[9]; - result.data[10] = data[10] / rhs.data[10]; - result.data[11] = data[11] / rhs.data[11]; - - result.data[12] = data[12] / rhs.data[12]; - result.data[13] = data[13] / rhs.data[13]; - result.data[14] = data[14] / rhs.data[14]; - result.data[15] = data[15] / rhs.data[15]; - - return result; - } - - /// Scalar divide operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix divide(Element const &s) const { - - Matrix result; - - result.data[0] = data[0] / s; - result.data[1] = data[1] / s; - result.data[2] = data[2] / s; - result.data[3] = data[3] / s; - - result.data[4] = data[4] / s; - result.data[5] = data[5] / s; - result.data[6] = data[6] / s; - result.data[7] = data[7] / s; - - result.data[8] = data[8] / s; - result.data[9] = data[9] / s; - result.data[10] = data[10] / s; - result.data[11] = data[11] / s; - - result.data[12] = data[12] / s; - result.data[13] = data[13] / s; - result.data[14] = data[14] / s; - result.data[15] = data[15] / s; - - return result; - } - - /// Scalar divide operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Element const &s) const { - return divide(s); - } - - /// Scalar divide operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Element const &s) { - - data[0] /= s; - data[1] /= s; - data[2] /= s; - data[3] /= s; - - data[4] /= s; - data[5] /= s; - data[6] /= s; - data[7] /= s; - - data[8] /= s; - data[9] /= s; - data[10] /= s; - data[11] /= s; - - data[12] /= s; - data[13] /= s; - data[14] /= s; - data[15] /= s; - - return *this; - } - - /// Elementwise divide operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix operator /(Matrix const &rhs) const { - return divide(rhs); - } - - /// Elementwise divide operator (4-by-4) - CUTLASS_HOST_DEVICE - Matrix & operator /=(Matrix const &rhs) { - - data[0] /= rhs.data[0]; - data[1] /= rhs.data[1]; - data[2] /= rhs.data[2]; - data[3] /= rhs.data[3]; - - data[4] /= rhs.data[4]; - data[5] /= rhs.data[5]; - data[6] /= rhs.data[6]; - data[7] /= rhs.data[7]; - - data[8] /= rhs.data[8]; - data[9] /= rhs.data[9]; - data[10] /= rhs.data[10]; - data[11] /= rhs.data[11]; - - data[12] /= rhs.data[12]; - data[13] /= rhs.data[13]; - data[14] /= rhs.data[14]; - data[15] /= rhs.data[15]; - - return *this; - } - - /// Negates each element of the matrix - CUTLASS_HOST_DEVICE - Matrix operator-() const { - Matrix m; - - m.data[0] = -data[0]; - m.data[1] = -data[1]; - m.data[2] = -data[2]; - m.data[3] = -data[3]; - m.data[4] = -data[4]; - m.data[5] = -data[5]; - m.data[6] = -data[6]; - m.data[7] = -data[7]; - m.data[8] = -data[8]; - m.data[9] = -data[9]; - m.data[10] = -data[10]; - m.data[11] = -data[11]; - m.data[12] = -data[12]; - m.data[13] = -data[13]; - m.data[14] = -data[14]; - m.data[15] = -data[15]; - - return m; - } - - /// Matrix product of size 4-by-1-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[4] * rhs.data[0]; - accum.data[2] += data[8] * rhs.data[0]; - accum.data[3] += data[12] * rhs.data[0]; - - // k=1 - accum.data[0] += data[1] * rhs.data[1]; - accum.data[1] += data[5] * rhs.data[1]; - accum.data[2] += data[9] * rhs.data[1]; - accum.data[3] += data[13] * rhs.data[1]; - - // k=2 - accum.data[0] += data[2] * rhs.data[2]; - accum.data[1] += data[6] * rhs.data[2]; - accum.data[2] += data[10] * rhs.data[2]; - accum.data[3] += data[14] * rhs.data[2]; - - // k=3 - accum.data[0] += data[3] * rhs.data[3]; - accum.data[1] += data[7] * rhs.data[3]; - accum.data[2] += data[11] * rhs.data[3]; - accum.data[3] += data[15] * rhs.data[3]; - - return accum; - } - - /// Matrix product of size 4-by-1-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[4] * rhs.data[0]; - accum.data[3] += data[4] * rhs.data[1]; - accum.data[4] += data[8] * rhs.data[0]; - accum.data[5] += data[8] * rhs.data[1]; - accum.data[6] += data[12] * rhs.data[0]; - accum.data[7] += data[12] * rhs.data[1]; - - // k=1 - accum.data[0] += data[1] * rhs.data[2]; - accum.data[1] += data[1] * rhs.data[3]; - accum.data[2] += data[5] * rhs.data[2]; - accum.data[3] += data[5] * rhs.data[3]; - accum.data[4] += data[9] * rhs.data[2]; - accum.data[5] += data[9] * rhs.data[3]; - accum.data[6] += data[13] * rhs.data[2]; - accum.data[7] += data[13] * rhs.data[3]; - - // k=2 - accum.data[0] += data[2] * rhs.data[4]; - accum.data[1] += data[2] * rhs.data[5]; - accum.data[2] += data[6] * rhs.data[4]; - accum.data[3] += data[6] * rhs.data[5]; - accum.data[4] += data[10] * rhs.data[4]; - accum.data[5] += data[10] * rhs.data[5]; - accum.data[6] += data[14] * rhs.data[4]; - accum.data[7] += data[14] * rhs.data[5]; - - // k=3 - accum.data[0] += data[3] * rhs.data[6]; - accum.data[1] += data[3] * rhs.data[7]; - accum.data[2] += data[7] * rhs.data[6]; - accum.data[3] += data[7] * rhs.data[7]; - accum.data[4] += data[11] * rhs.data[6]; - accum.data[5] += data[11] * rhs.data[7]; - accum.data[6] += data[15] * rhs.data[6]; - accum.data[7] += data[15] * rhs.data[7]; - - return accum; - } - - /// Matrix product of size 4-by-2-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[4] * rhs.data[0]; - accum.data[4] += data[4] * rhs.data[1]; - accum.data[5] += data[4] * rhs.data[2]; - accum.data[6] += data[8] * rhs.data[0]; - accum.data[7] += data[8] * rhs.data[1]; - accum.data[8] += data[8] * rhs.data[2]; - accum.data[9] += data[12] * rhs.data[0]; - accum.data[10] += data[12] * rhs.data[1]; - accum.data[11] += data[12] * rhs.data[2]; - - // k=1 - accum.data[0] += data[1] * rhs.data[3]; - accum.data[1] += data[1] * rhs.data[4]; - accum.data[2] += data[1] * rhs.data[5]; - accum.data[3] += data[5] * rhs.data[3]; - accum.data[4] += data[5] * rhs.data[4]; - accum.data[5] += data[5] * rhs.data[5]; - accum.data[6] += data[9] * rhs.data[3]; - accum.data[7] += data[9] * rhs.data[4]; - accum.data[8] += data[9] * rhs.data[5]; - accum.data[9] += data[13] * rhs.data[3]; - accum.data[10] += data[13] * rhs.data[4]; - accum.data[11] += data[13] * rhs.data[5]; - - // k=2 - accum.data[0] += data[2] * rhs.data[6]; - accum.data[1] += data[2] * rhs.data[7]; - accum.data[2] += data[2] * rhs.data[8]; - accum.data[3] += data[6] * rhs.data[6]; - accum.data[4] += data[6] * rhs.data[7]; - accum.data[5] += data[6] * rhs.data[8]; - accum.data[6] += data[10] * rhs.data[6]; - accum.data[7] += data[10] * rhs.data[7]; - accum.data[8] += data[10] * rhs.data[8]; - accum.data[9] += data[14] * rhs.data[6]; - accum.data[10] += data[14] * rhs.data[7]; - accum.data[11] += data[14] * rhs.data[8]; - - // k=3 - accum.data[0] += data[3] * rhs.data[9]; - accum.data[1] += data[3] * rhs.data[10]; - accum.data[2] += data[3] * rhs.data[11]; - accum.data[3] += data[7] * rhs.data[9]; - accum.data[4] += data[7] * rhs.data[10]; - accum.data[5] += data[7] * rhs.data[11]; - accum.data[6] += data[11] * rhs.data[9]; - accum.data[7] += data[11] * rhs.data[10]; - accum.data[8] += data[11] * rhs.data[11]; - accum.data[9] += data[15] * rhs.data[9]; - accum.data[10] += data[15] * rhs.data[10]; - accum.data[11] += data[15] * rhs.data[11]; - - return accum; - } - - /// Matrix product of size 4-by-3-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix product( - Matrix const &rhs, - Matrix accum = Matrix() - ) const { - - // k=0 - accum.data[0] += data[0] * rhs.data[0]; - accum.data[1] += data[0] * rhs.data[1]; - accum.data[2] += data[0] * rhs.data[2]; - accum.data[3] += data[0] * rhs.data[3]; - accum.data[4] += data[4] * rhs.data[0]; - accum.data[5] += data[4] * rhs.data[1]; - accum.data[6] += data[4] * rhs.data[2]; - accum.data[7] += data[4] * rhs.data[3]; - accum.data[8] += data[8] * rhs.data[0]; - accum.data[9] += data[8] * rhs.data[1]; - accum.data[10] += data[8] * rhs.data[2]; - accum.data[11] += data[8] * rhs.data[3]; - accum.data[12] += data[12] * rhs.data[0]; - accum.data[13] += data[12] * rhs.data[1]; - accum.data[14] += data[12] * rhs.data[2]; - accum.data[15] += data[12] * rhs.data[3]; - - // k=1 - accum.data[0] += data[1] * rhs.data[4]; - accum.data[1] += data[1] * rhs.data[5]; - accum.data[2] += data[1] * rhs.data[6]; - accum.data[3] += data[1] * rhs.data[7]; - accum.data[4] += data[5] * rhs.data[4]; - accum.data[5] += data[5] * rhs.data[5]; - accum.data[6] += data[5] * rhs.data[6]; - accum.data[7] += data[5] * rhs.data[7]; - accum.data[8] += data[9] * rhs.data[4]; - accum.data[9] += data[9] * rhs.data[5]; - accum.data[10] += data[9] * rhs.data[6]; - accum.data[11] += data[9] * rhs.data[7]; - accum.data[12] += data[13] * rhs.data[4]; - accum.data[13] += data[13] * rhs.data[5]; - accum.data[14] += data[13] * rhs.data[6]; - accum.data[15] += data[13] * rhs.data[7]; - - // k=2 - accum.data[0] += data[2] * rhs.data[8]; - accum.data[1] += data[2] * rhs.data[9]; - accum.data[2] += data[2] * rhs.data[10]; - accum.data[3] += data[2] * rhs.data[11]; - accum.data[4] += data[6] * rhs.data[8]; - accum.data[5] += data[6] * rhs.data[9]; - accum.data[6] += data[6] * rhs.data[10]; - accum.data[7] += data[6] * rhs.data[11]; - accum.data[8] += data[10] * rhs.data[8]; - accum.data[9] += data[10] * rhs.data[9]; - accum.data[10] += data[10] * rhs.data[10]; - accum.data[11] += data[10] * rhs.data[11]; - accum.data[12] += data[14] * rhs.data[8]; - accum.data[13] += data[14] * rhs.data[9]; - accum.data[14] += data[14] * rhs.data[10]; - accum.data[15] += data[14] * rhs.data[11]; - - // k=3 - accum.data[0] += data[3] * rhs.data[12]; - accum.data[1] += data[3] * rhs.data[13]; - accum.data[2] += data[3] * rhs.data[14]; - accum.data[3] += data[3] * rhs.data[15]; - accum.data[4] += data[7] * rhs.data[12]; - accum.data[5] += data[7] * rhs.data[13]; - accum.data[6] += data[7] * rhs.data[14]; - accum.data[7] += data[7] * rhs.data[15]; - accum.data[8] += data[11] * rhs.data[12]; - accum.data[9] += data[11] * rhs.data[13]; - accum.data[10] += data[11] * rhs.data[14]; - accum.data[11] += data[11] * rhs.data[15]; - accum.data[12] += data[15] * rhs.data[12]; - accum.data[13] += data[15] * rhs.data[13]; - accum.data[14] += data[15] * rhs.data[14]; - accum.data[15] += data[15] * rhs.data[15]; - - return accum; - } - - /// Matrix product of size 4-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix operator*(Matrix const &rhs) const { - return product(rhs); - } - - /// Matrix product of size 4-by-4-by-4 - CUTLASS_HOST_DEVICE - Matrix & operator*=(Matrix const &rhs) { - *this = product(rhs); - return *this; - } - - /// Returns the sum of elements - CUTLASS_HOST_DEVICE - Element sum(Element accum = Element()) const { - - accum += data[0]; - accum += data[1]; - accum += data[2]; - accum += data[3]; - accum += data[4]; - accum += data[5]; - accum += data[6]; - accum += data[7]; - accum += data[8]; - accum += data[9]; - accum += data[10]; - accum += data[11]; - accum += data[12]; - accum += data[13]; - accum += data[14]; - accum += data[15]; - - return accum; - } - - /// Returns the sum of squared elements - CUTLASS_HOST_DEVICE - Element norm(Element accum = Element()) const { - - accum += data[0] * data[0]; - accum += data[1] * data[1]; - accum += data[2] * data[2]; - accum += data[3] * data[3]; - accum += data[4] * data[4]; - accum += data[5] * data[5]; - accum += data[6] * data[6]; - accum += data[7] * data[7]; - accum += data[8] * data[8]; - accum += data[9] * data[9]; - accum += data[10] * data[10]; - accum += data[11] * data[11]; - accum += data[12] * data[12]; - accum += data[13] * data[13]; - accum += data[14] * data[14]; - accum += data[15] * data[15]; - - return accum; - } - - /// Returns square root of the norm - CUTLASS_HOST_DEVICE - Element magnitude() const { - return fast_sqrt(norm()); - } - - /// Returns the sum of diagonal elements - CUTLASS_HOST_DEVICE - Element trace(Element accum = Element()) const { - - accum += data[0]; - accum += data[5]; - accum += data[10]; - accum += data[15]; - - return accum; - } - - /// Returns 4-by-4 rotation matrix around the X axis - CUTLASS_HOST_DEVICE - static Matrix rotation_X(Element theta) { - Matrix m = identity(); - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - m.at(1, 1) = c; - m.at(1, 2) = -s; - m.at(2, 1) = s; - m.at(2, 2) = c; - - return m; - } - - /// Returns 4-by-4 rotation matrix around the Y axis - CUTLASS_HOST_DEVICE - static Matrix rotation_Y(Element theta) { - Matrix m = identity(); - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - m.at(0, 0) = c; - m.at(2, 0) = -s; - m.at(0, 2) = s; - m.at(2, 2) = c; - - return m; - } - - /// Returns 4-by-4 rotation matrix around the Z axis - CUTLASS_HOST_DEVICE - static Matrix rotation_Z(Element theta) { - Matrix m = Matrix::identity(); - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - m.at(0, 0) = c; - m.at(0, 1) = -s; - m.at(1, 0) = s; - m.at(1, 1) = c; - - return m; - } - - /// Returns a 4-by-4 rotation matrix around a unit-length axis - CUTLASS_HOST_DEVICE - static Matrix rotation(Element theta, Matrix const &u) { - Element x = u.data[0]; - Element y = u.data[1]; - Element z = u.data[2]; - - Element c = fast_cos(theta); - Element s = fast_sin(theta); - - Element one_minus_cos = Element(1) - fast_cos(theta); - - Matrix m; - - m.set_slice_3x3({ - c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, - y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, - z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos - }); - - return m; - } - - /// Returns a 4-by-4 reflection about the plane specified by the - /// unit-length normal vector n_unit - CUTLASS_HOST_DEVICE - static Matrix reflection(Matrix const &n_unit) { - - Element a = n_unit.data[0]; - Element b = n_unit.data[1]; - Element c = n_unit.data[2]; - - Matrix m = Matrix::identity(); - - m.set_slice_3x3({ - Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, - Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, - Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c - }); - - return m; - } - - /// Returns a perspective projection matrix typical of OpenGL applications - CUTLASS_HOST_DEVICE - static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) { - Element aspect = fovH / fovV; - Element f = Element(cos(fovV)) / Element(fovH); - Element Q = near_plane - far_plane; - - return Matrix( - f / aspect, 0, 0, 0, - 0, f, 0, 0, - 0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q, - 0, 0, -1, 0 - ); - } - - CUTLASS_HOST_DEVICE - static Matrix translation(Matrix const &v) { - return Matrix( - 1, 0, 0, v.data[0], - 0, 1, 0, v.data[1], - 0, 0, 1, v.data[2], - 0, 0, 0, 1 - ); - } - - /// Computes the determinant of a 4-by-4 matrix - CUTLASS_HOST_DEVICE - Element determinant(Element accum = Element()) const { - - accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(1, 3), at(2, 1), at(2, 2), at(2, 3), at(3, 1), at(3, 2), at(3, 3) }).determinant(); - accum -= at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(1, 3), at(2, 0), at(2, 2), at(2, 3), at(3, 0), at(3, 2), at(3, 3) }).determinant(); - accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(1, 3), at(2, 0), at(2, 1), at(2, 3), at(3, 0), at(3, 1), at(3, 3) }).determinant(); - accum -= at(0, 3) * Matrix({ at(1, 0), at(1, 1), at(1, 2), at(2, 0), at(2, 1), at(2, 2), at(3, 0), at(3, 1), at(3, 2) }).determinant(); - - return accum; - } - - /// Computes the inverse of a 4-by-4 matrix (ignores the optional argument) - CUTLASS_HOST_DEVICE - Matrix inverse(Element ignore = 1) const { - Matrix B = slice_2x2(0, 2); - Matrix A = slice_2x2(0, 0); - Matrix C = slice_2x2(2, 0); - Matrix D = slice_2x2(2, 2); - - Matrix D_inv = D.inverse(); - - Matrix E = (A - B * D_inv * C).inverse(); - - return Matrix::block( - E, -E * B * D_inv, - -D_inv * C * E, D_inv + D_inv * C * E * B * D_inv - ); - } - -}; - -/// Template alias for 4-by-4 matrix -template -using Matrix4x4 = Matrix; - - -/// Free function to infer element type from template arguments -template -CUTLASS_HOST_DEVICE Matrix4x4 make_Matrix4x4( - Element _0_0, Element _0_1, Element _0_2, Element _0_3, - Element _1_0, Element _1_1, Element _1_2, Element _1_3, - Element _2_0, Element _2_1, Element _2_2, Element _2_3, - Element _3_0, Element _3_1, Element _3_2, Element _3_3 -) { - return Matrix4x4( - _0_0, _0_1, _0_2, _0_3, - _1_0, _1_1, _1_2, _1_3, - _2_0, _2_1, _2_2, _2_3, - _3_0, _3_1, _3_2, _3_3 - ); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Elementwise scalar multiplication -template -CUTLASS_HOST_DEVICE -Matrix operator*(Element s, Matrix const &rhs) { - return rhs.multiply(s); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix_coord.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix_coord.h deleted file mode 100644 index 85d447b1398e844011a798e2d818543f2d51bba4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix_coord.h +++ /dev/null @@ -1,164 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a canonical coordinate for rank=2 matrices offering named indices. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes -/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord. -struct MatrixCoord : public Coord<2, int> { - -public: - - /// Integer-valued index - using Index = int; - - /// Base type is a Coord of rank=2 - using Base = Coord<2, Index>; - - /// LongIndex type - using LongIndex = typename Base::LongIndex; - -private: - - /// Rows dimension - static int const kRow = 0; - - /// Columns dimension - static int const kColumn = 1; - -public: - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - MatrixCoord() { } - - /// Constructs from Coord<2> - CUTLASS_HOST_DEVICE - MatrixCoord(Coord<2, Index> const &coord): Base(coord) { } - - /// Helper to construct from a row and column - CUTLASS_HOST_DEVICE - MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { } - - /// Helper to construct from a row and column, which are LongIndex based - CUTLASS_HOST_DEVICE - MatrixCoord(LongIndex row, LongIndex column): Base(make_Coord(Index(row), Index(column))) { } - - /// Returns the row of the coordinate - CUTLASS_HOST_DEVICE - Index const & row() const { return this->at(kRow); } - - /// Returns the row of the coordinate - CUTLASS_HOST_DEVICE - Index & row() { return this->at(kRow); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index const & column() const { return this->at(kColumn); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index & column() { return this->at(kColumn); } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - MatrixCoord operator+(Base const& b) const { - return MatrixCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - MatrixCoord operator-(Base const& b) const { - return MatrixCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - MatrixCoord operator*(Base const& b) const { - return MatrixCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - MatrixCoord operator/(Base const& b) const { - return MatrixCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - MatrixCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - MatrixCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - MatrixCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - MatrixCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix_shape.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix_shape.h deleted file mode 100644 index 20d668b248daac24cf152ba6ec72c5d47ad319e9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/matrix_shape.h +++ /dev/null @@ -1,65 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a Shape template for matrix tiles -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Describes the size of a matrix tile -template < - int Row_, ///< rows of a matrix - int Column_ ///< columns of a matrix -> -struct MatrixShape { - static int const kRow = Row_; ///< rows of a matrix - static int const kColumn = Column_; ///< columns of a matrix - static int const kCount = Row_ * Column_; ///< total number of elements in a matrix - - // - // Static member functions - // - - CUTLASS_HOST_DEVICE - static Coord<2> toCoord() { - return make_Coord(kRow, kColumn); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_conversion.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_conversion.h deleted file mode 100644 index 7aad6c24193c19537340f50777ac62a645465902..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_conversion.h +++ /dev/null @@ -1,7123 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Boost-like numeric conversion operator for CUTLASS numeric types -*/ - -#pragma once - -#if !defined(__CUDACC_RTC__) -#include -#endif - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/thread/unary_op.h" - -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/bfloat16.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Floating-point rounding style similar to Standard Library's formats but supporting -/// additional rounding options. -enum class FloatRoundStyle { - round_indeterminate, ///< rounding mode unknown - round_toward_zero, ///< round toward zero - round_to_nearest, ///< round to nearest even - round_to_nearest_satfinite, ///< round to nearest even, capping value to min and max of destination type - round_toward_infinity, ///< round toward infinity - round_toward_neg_infinity, ///< round toward negative infinity - round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero - round_half_ulp_trunc_dntz ///< like round_half_ulp_truncate, except denorms are rounded *toward* zero -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename S, - FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -> -struct NumericConverter { - - using result_type = T; - using source_type = S; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - return static_cast(s); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for float => int32_t -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct NumericConverter { - - using result_type = int32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if __CUDA_ARCH__ - return __float2int_rn(s); - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TONEAREST); - return static_cast(std::nearbyint(s)); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = int32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if __CUDA_ARCH__ - return __float2int_rz(s); - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TOWARDZERO); - return (result_type)std::nearbyint(s); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for float => int8_t -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if defined(__CUDA_ARCH__) - int32_t intermediate; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TONEAREST); - int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if defined(__CUDA_ARCH__) - int32_t intermediate; - asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TOWARDZERO); - int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = uint8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if defined(__CUDA_ARCH__) - int32_t intermediate; - asm volatile("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TONEAREST); - int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = uint8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if __CUDA_ARCH__ - int32_t intermediate; - asm volatile("cvt.rzi.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TOWARDZERO); - int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for cutlass::half_t => int8_t -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = cutlass::half_t; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - #if defined(__CUDA_ARCH__) - union { int8_t int8[2]; int16_t int16; }; - union { cutlass::half_t fp16; int16_t int16_in; }; - fp16 = s; - asm volatile ("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); - return int8[0]; - #elif !defined(__CUDACC_RTC__) - std::fesetround(FE_TONEAREST); - int32_t intermediate = (int32_t)std::nearbyint(static_cast(s)); - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for float => integer_subbyte -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct NumericConverter, float, Round> { -private: - static constexpr bool result_is_signed = true; - -public: - using result_type = integer_subbyte; - using source_type = float; - static constexpr FloatRoundStyle round_style = Round; - - CUTLASS_HOST_DEVICE static result_type - convert(source_type const& src) { - using middle_type = int; - static_assert(8 * sizeof(middle_type) > Bits, "This conversion " - "requires that integer_subbyte have fewer representation bits " - "than the number of bits in int."); - - auto middle = NumericConverter::convert(src); - return NumericConverter::convert(middle); - } - - CUTLASS_HOST_DEVICE result_type - operator()(source_type const& s) const { - return convert(s); - } -}; - -template -struct NumericConverter, float, Round> { -private: - static constexpr bool result_is_signed = false; - -public: - using result_type = integer_subbyte; - using source_type = float; - static constexpr FloatRoundStyle round_style = Round; - - CUTLASS_HOST_DEVICE static result_type - convert(source_type const& src) { - using middle_type = unsigned; - static_assert(8 * sizeof(middle_type) > Bits, "This conversion " - "requires that integer_subbyte have fewer representation bits " - "than the number of bits in unsigned int."); - - auto middle = NumericConverter::convert(src); - return NumericConverter::convert(middle); - } - - CUTLASS_HOST_DEVICE result_type - operator()(source_type const& s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for float <= cutlass::half_t -template -struct NumericConverter { - - using result_type = T; - using source_type = T; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - return s; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for float <=> cutlass::half_t -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for float <= cutlass::half_t -template -struct NumericConverter { - - using result_type = float; - using source_type = cutlass::half_t; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - result_type result = static_cast(s); - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Specialization for round-to-nearest -template <> -struct NumericConverter { - - using result_type = cutlass::half_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - result_type result = static_cast(s); - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Specialization for round-toward-zero -template <> -struct NumericConverter { - - using result_type = cutlass::half_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - /// Round toward zero - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & flt) { - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - return cutlass::half_t(__float2half_rz(flt)); - #else - // software implementation rounds toward nearest even - unsigned const& s = reinterpret_cast(flt); - uint16_t sign = uint16_t((s >> 16) & 0x8000); - int32_t exp = int32_t((s >> 23) & 0xff) - 127; - int mantissa = s & 0x7fffff; - uint16_t u = 0; - - if ((s & 0x7fffffff) == 0) { - // sign-preserving zero - return cutlass::half_t::bitcast(sign); - } - - if (exp > 15) { - if (exp == 128 && mantissa) { - // not a number - u = 0x7fff; - } else { - // overflow to infinity - u = sign | 0x7c00; - } - return cutlass::half_t::bitcast(u); - } - - if (exp >= -14) { - // normal fp32 to normal fp16 - u = uint16_t((uint32_t(exp + 15) & 0x1f) << 10); - u = uint16_t(u | (mantissa >> 13)); - } else { - // normal single-precision to subnormal cutlass::half_t-precision representation - int rshift = (-14 - exp); - if (rshift < 32) { - mantissa |= (1 << 23); - mantissa = (mantissa >> rshift); - u = (uint16_t(mantissa >> 13) & 0x3ff); - } else { - mantissa = 0; - u = 0; - } - } - - u |= sign; - - return cutlass::half_t::bitcast(u); - - #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for float <=> cutlass::bfloat16_t -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for float <= cutlass::bfloat16_t -template -struct NumericConverter { - - using result_type = float; - using source_type = cutlass::bfloat16_t; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - return static_cast(s); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - using result_type = cutlass::bfloat16_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - return static_cast(s); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - using result_type = cutlass::bfloat16_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - uint32_t x32 = reinterpret_cast(s); - - #if defined(__CUDA_ARCH__) - if (::isfinite(s)) { - x32 += 0x8000; - } - #else - if (std::isfinite(s)) { - x32 += 0x8000; - } - #endif - - uint16_t x16 = uint16_t((x32 >> 16) & 0xffff); - return cutlass::bfloat16_t::bitcast(x16); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - using result_type = cutlass::bfloat16_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - uint32_t x32 = reinterpret_cast(s); - uint16_t x16 = uint16_t(x32 >> 16); - - return cutlass::bfloat16_t::bitcast(x16); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for float <=> cutlass::tfloat32_t -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for float <= cutlass::tfloat32_t -template -struct NumericConverter { - - using result_type = float; - using source_type = cutlass::tfloat32_t; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - return static_cast(s); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - using result_type = cutlass::tfloat32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - unsigned storage = reinterpret_cast(s); - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 - asm volatile("cvt.rn.tf32.f32 %0, %1;" : "=r"(storage) : "r"(storage)); -#else - if ((storage & 0x7f800000) != 0x7f800000) { - - bool mantissa_bit = ((storage & (1 << 13)) != 0); - bool round_bit = ((storage & (1 << 12)) != 0); - bool sticky_bit = ((storage & ((1 << 12) - 1)) != 0); - - if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { - storage += uint32_t(1 << 13); - } - - // Note, the following is intentionally commented out. TF32 - // does not define the low order bits, so they may be left in - // an undefined state. - // - // By not truncating these bit explicitly, we avoid an extra logical - // operation. - // - // TF32 may be implicitly converted to float by performing this - // operation as needed. - // - // storage = (storage & ~0x1fff); - } - else if (storage & ~0xff800000) { - storage = 0x7fffffff; - } -#endif - - return cutlass::tfloat32_t::bitcast(storage); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - using result_type = cutlass::tfloat32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - return cutlass::tfloat32_t::round_half_ulp_truncate(s); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// This rounding operation is similar to half_ulp_truncate except it rounds denorms toward zero. -/// It avoids predicated code, though it requires a temporary register. -template <> -struct NumericConverter { - using result_type = cutlass::tfloat32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_trunc_dntz; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - unsigned y = reinterpret_cast(s); - y = y & 0xff800000; - float d = reinterpret_cast(y); - float z = d / float(1 << 11) + s; - - return reinterpret_cast(z); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - using result_type = cutlass::tfloat32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - uint32_t x = reinterpret_cast(s); - return cutlass::tfloat32_t::bitcast(x & 0xffffe000); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Conversion operator for float to cutlass::tfloat32_t big and small values -// -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - FloatRoundStyle RoundBig = FloatRoundStyle::round_toward_zero, - FloatRoundStyle RoundSmall = FloatRoundStyle::round_half_ulp_truncate -> -struct NumericConverterFastF32 { - - // result_type holds big cutlass::tfloat32_t at idx(0) and small cutlass::tfloat32_t at idx(1) - using result_type = Array; - - // source data type - using source_type = float; - - // rounding styles for big and small part - static FloatRoundStyle const kRoundBig = RoundBig; - static FloatRoundStyle const kRoundSmall = RoundSmall; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - result_type result; - NumericConverter convert_big_; - NumericConverter convert_small_; - - // convert and fill cutlass::tfloat32_t big at idx 0 - result[0] = convert_big_(source); - - // convert and fill cutlass::tfloat32_t small at idx 1 - result[1] = convert_small_(source - static_cast(result[0])); - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Conversion and Clamp operator for Integers -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename S -> -struct NumericConverterClamp { - - using result_type = T; - using source_type = S; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - NumericConverter convert_op; - result_type const kClamp_max = cutlass::platform::numeric_limits::max(); - result_type const kClamp_min = cutlass::platform::numeric_limits::lowest(); - if (s < (source_type)kClamp_min) - return kClamp_min; - if (s > (source_type)kClamp_max) - return kClamp_max; - return convert_op(s); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -// This converter is needed to enable cutlass::half_t output types when using int32_t accumulators. -// Since floating-point types do not require a clamp, this converter simply casts from -// the source type to cutlass::half_t. -template < - typename S -> -struct NumericConverterClamp { - - using result_type = cutlass::half_t; - using source_type = S; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const &source) { - return static_cast(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Conversion operator for Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Conversion operator for Array -template < - typename T, - typename S, - int N, - FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, - typename Transform = cutlass::transform::thread::UnaryTransform::Identity -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - static_assert(platform::is_same::value || - platform::is_same::value, - "Unary Operator not supported."); - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - result_type result; - NumericConverter convert_; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - if (platform::is_same::value) { - result[i] = convert_(s[i]); - } else { // conjugate - result[i] = conj(convert_(s[i])); - } - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - typename T, - int N, - FloatRoundStyle Round, - typename Transform -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - static_assert(platform::is_same::value || - platform::is_same::value, - "Unary Operator not supported."); - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const &source) { - if (platform::is_same::value) { - return source; - } else { - result_type result; - for (int i = 0; i < N; ++i) { - result[i] = conj(static_cast(source[i])); - } - return result; - } - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array, round to nearest -template <> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - Array result; - reinterpret_cast<__half2 &>(result) = __float22half2_rn(reinterpret_cast(source)); - return result; - #else - NumericConverter convert_; - // NOTE: cutlass::Array is NOT an aggregate type and - // below `{}` does NOT conduct zero initialization. Below `{}` will - // conduct default initialization (calling default ctr). We use this syntax - // to resolve compiler warning on uninitialized member variable. - Array result{}; - result[0] = convert_(source[0]); - result[1] = convert_(source[1]); - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array, round to nearest -template -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - float2 result2 = __half22float2(reinterpret_cast<__half2 const &>(source)); - return { - float{result2.x}, - float{result2.y} - }; - #else - NumericConverter convert_; - return { - convert_(source[0]), - convert_(source[1]) - }; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - NumericConverter convert_element_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - if (N % 2) { - result[N - 1] = convert_element_(source[N - 1]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - NumericConverter convert_element_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - if (N % 2) { - result[N - 1] = convert_element_(source[N - 1]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array, round to nearest -template <> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - unsigned d; - - asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(d) : "f"(source[1]), "f"(source[0]) ); - - return reinterpret_cast(d); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -/// Partial specialization for Array <= Array, round to nearest with min/max saturation -template <> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest_satfinite; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - unsigned d; - - asm("cvt.rn.satfinite.bf16x2.f32 %0, %1, %2;\n" : "=r"(d) : "f"(source[1]), "f"(source[0]) ); - - return reinterpret_cast(d); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - NumericConverter convert_element_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - if (N % 2) { - result[N - 1] = convert_element_(source[N - 1]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#endif // if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Conditional guards to enable partial specialization for packed integers -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ - ((__CUDACC_VER_MAJOR__ > 10) || \ - ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericConverter convert_element_; - - result_type result; - - result[0] = convert_element_(source[0]); - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - uint32_t tmp; - - asm volatile( - "cvt.pack.sat.s8.s32.b32 %0, %2, %1, 0;\n" - : "=r"(tmp) : "r"(source[0]), "r"(source[1])); - - uint16_t out = (tmp & 0xffff); - return reinterpret_cast(out); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - unsigned out; - - asm volatile( - "{ .reg .u32 r4;" - "cvt.pack.sat.s8.s32.b32 r4, %4, %3, 0;" - "cvt.pack.sat.s8.s32.b32 %0, %2, %1, r4;" - "}" - : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); - - return reinterpret_cast(out); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 4), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericConverter convert_element_; - - result_type result; - - result[0] = convert_element_(source[0]); - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - uint32_t tmp; - - asm volatile( - "cvt.pack.sat.u8.s32.b32 %0, %2, %1, 0;\n" - : "=r"(tmp) : "r"(source[0]), "r"(source[1])); - - uint16_t out = (tmp & 0xffff); - return reinterpret_cast(out); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - unsigned out; - - asm volatile( - "{ .reg .u32 r4;" - "cvt.pack.sat.u8.s32.b32 r4, %4, %3, 0;" - "cvt.pack.sat.u8.s32.b32 %0, %2, %1, r4;" - "}" - : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); - - return reinterpret_cast(out); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 4), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out_fp16; - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ - "}\n" : "=r"(out_fp16): "h"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16)); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e4m3_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(source[0]), "f"(source[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out_fp16; - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e5m2x2 %0, %1;\n" \ - "}\n" : "=r"(out_fp16): "h"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16)); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e5m2_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e5m2x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(source[0]), "f"(source[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - result_type out; - uint32_t& reg = reinterpret_cast(out); - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ - "}\n" : "=r"(reg): "h"(src_packed)); - - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::half_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;\n" \ - "}" \ - : "=h"(out) : "r"(reinterpret_cast(source))); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - result_type out; - uint32_t& reg = reinterpret_cast(out); - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e5m2x2 %0, %1;\n" \ - "}\n" : "=r"(reg): "h"(src_packed)); - - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::half_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;\n" \ - "}" \ - : "=h"(out) : "r"(reinterpret_cast(source))); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::bfloat16_t; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t res_half; - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ - "}\n" : "=r"(res_half): "h"(src_packed)); - float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half)); - NumericArrayConverter converter; - return converter(reinterpret_cast const&>(res_float)); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::bfloat16_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - NumericArrayConverter converter; - Array res_float = converter(source); - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(res_float[0]), "f"(res_float[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::bfloat16_t; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t res_half; - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e5m2x2 %0, %1;\n" \ - "}\n" : "=r"(res_half): "h"(src_packed)); - float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half)); - NumericArrayConverter converter; - return converter(reinterpret_cast const&>(res_float)); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::bfloat16_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - NumericArrayConverter converter; - Array res_float = converter(source); - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e5m2x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(res_float[0]), "f"(res_float[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -namespace detail { - -/// Special converters that can be used with 4 8-bit elements packed in a register. -/// Common use is for fast FP8 converters. -template < - typename T, - typename S, - FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, - typename Transform = cutlass::transform::thread::UnaryTransform::Identity -> -struct NumericArrayConverterPacked4Element { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - static_assert(platform::is_same::value || - platform::is_same::value, - "Unary Operator not supported."); - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - result_type result; - NumericConverter convert_; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - if (platform::is_same::value) { - result[i] = convert_(s[i]); - } - else { // conjugate - result[i] = conj(convert_(s[i])); - } - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out_fp16[2]; - uint32_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out; - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ - "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float; - using source_element = float_ue4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out_fp16[2]; - uint32_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float_ue4m3_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out; - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ - "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float; - using source_element = float_ue8m0_t; - - using result_type = Array; - using source_type = Array; - using BfloatArr = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) - uint32_t out_fp16[2]; - uint32_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.bf16x2.ue8m0x2 %0, lo;\n" \ - "cvt.rn.bf16x2.ue8m0x2 %1, hi;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); - - NumericArrayConverter bf2fp32_converter; - auto res0 = bf2fp32_converter(reinterpret_cast &>(out_fp16[0])); - auto res1 = bf2fp32_converter(reinterpret_cast &>(out_fp16[1])); - - result_type out; - out[0] = res0[0]; - out[1] = res0[1]; - out[2] = res1[0]; - out[3] = res1[1]; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -/// Partial specialization for Array <= Array -template <> -struct NumericArrayConverterPacked4Element { - using result_element = float_ue8m0_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_infinity; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) - uint32_t out; - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rp.satfinite.ue8m0x2.f32 lo, %2, %1;\n" \ - "cvt.rp.satfinite.ue8m0x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template <> -struct NumericArrayConverterPacked4Element { - using result_element = float_ue8m0_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) - uint32_t out; - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rz.satfinite.ue8m0x2.f32 lo, %2, %1;\n" \ - "cvt.rz.satfinite.ue8m0x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float_ue8m0_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_infinity; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - //default maps to RP mode. - return NumericArrayConverterPacked4Element{}(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::detail::float_e2m3_unpack8bits_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t out; - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e2m3x2.f32 lo, %2, %1;\n" \ - "cvt.rn.satfinite.e2m3x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float; - using source_element = cutlass::detail::float_e2m3_unpack8bits_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t out_fp16[2]; - uint32_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e2m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e2m3x2 %1, hi;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::detail::float_e3m2_unpack8bits_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t out; - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e3m2x2.f32 lo, %2, %1;\n" \ - "cvt.rn.satfinite.e3m2x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float; - using source_element = cutlass::detail::float_e3m2_unpack8bits_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t out_fp16[2]; - uint32_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e3m2x2 %0, lo;\n" \ - "cvt.rn.f16x2.e3m2x2 %1, hi;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = float; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out_fp16[2]; - uint32_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ - "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out; - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ - "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out[2]; - uint32_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ - "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::half_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out; - uint32_t const* src_packed = reinterpret_cast(&source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ - "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out[2]; - uint32_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ - "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ - "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::half_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out; - uint32_t const* src_packed = reinterpret_cast(&source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ - "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::bfloat16_t; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - // Convert f8 to float - NumericArrayConverterPacked4Element src2float; - Array tmp_floats = src2float(source); - - // Convert float to bf16 - result_type out; - Array* packed_tmp = reinterpret_cast*>(&tmp_floats); - Array* packed_out = reinterpret_cast*>(&out); - NumericArrayConverter float2result; - packed_out[0] = float2result(packed_tmp[0]); - packed_out[1] = float2result(packed_tmp[1]); - - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::bfloat16_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - // Convert bf16 to float - Array tmp; - Array* packed_tmp = reinterpret_cast*>(&tmp); - Array const* packed_source = reinterpret_cast const*>(&source); - NumericArrayConverter src2float; - packed_tmp[0] = src2float(packed_source[0]); - packed_tmp[1] = src2float(packed_source[1]); - - // Convert float to f8 - NumericArrayConverterPacked4Element float2result; - return float2result(tmp); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::bfloat16_t; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - // Convert f8 to float - NumericArrayConverterPacked4Element src2float; - Array tmp_floats = src2float(source); - - // Convert float to bf16 - result_type out; - Array* packed_tmp = reinterpret_cast*>(&tmp_floats); - Array* packed_out = reinterpret_cast*>(&out); - NumericArrayConverter float2result; - packed_out[0] = float2result(packed_tmp[0]); - packed_out[1] = float2result(packed_tmp[1]); - - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::bfloat16_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - // Convert bf16 to float - Array tmp; - Array* packed_tmp = reinterpret_cast*>(&tmp); - Array const* packed_source = reinterpret_cast const*>(&source); - NumericArrayConverter src2float; - packed_tmp[0] = src2float(packed_source[0]); - packed_tmp[1] = src2float(packed_source[1]); - - // Convert float to f8 - NumericArrayConverterPacked4Element float2result; - return float2result(tmp); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for: -// Array <=> Array -// Array <=> Array -// using packed converter under the hood -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename S, - int N, - FloatRoundStyle Round -> -struct PackedNumericArrayConverter { - using result_element = T; - using source_element = S; - - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using packed_result_type = Array; - using packed_source_type = Array; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - result_type result; - packed_result_type* packed_result = reinterpret_cast(&result); - const packed_source_type* packed_source = reinterpret_cast(&source); - - detail::NumericArrayConverterPacked4Element packed_converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - packed_result[i] = packed_converter(packed_source[i]); - } - - // Handle leftovers - NumericConverter converter; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N % 4; ++i) { - int idx = ((N / 4) * 4) + i; - result[idx] = converter(source[idx]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const{ - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float; - using source_element = float_ue8m0_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) - uint32_t out_fp16; - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.bf16x2.ue8m0x2 %0, %1;\n" \ - "}\n" : "=r"(out_fp16): "h"(src_packed)); - - NumericArrayConverter bf2fp32_converter; - auto res0 = bf2fp32_converter(reinterpret_cast &>(out_fp16)); - - result_type out; - out[0] = res0[0]; - out[1] = res0[1]; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template <> -struct NumericArrayConverter { - using result_element = float_ue8m0_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_infinity; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) - uint16_t out; - asm volatile( \ - "{\n" \ - "cvt.rp.satfinite.ue8m0x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(source[0]), "f"(source[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template <> -struct NumericArrayConverter { - using result_element = float_ue8m0_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) - uint16_t out; - asm volatile( \ - "{\n" \ - "cvt.rz.satfinite.ue8m0x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(source[0]), "f"(source[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float_ue8m0_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - return NumericArrayConverter{}(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float; - using source_element = float_ue4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out_fp16; - uint16_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ - "}\n" : "=r"(out_fp16): "h"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16)); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float_ue4m3_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint16_t out; - - asm volatile( \ - "{\n" \ - "cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;\n" \ - "}" \ - : "=h"(out) : "f"(source[0]), "f"(source[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - - -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - - -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float; - using source_element = cutlass::float_e2m1_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t out_fp16[4]; - uint32_t const& src_packed = reinterpret_cast(source); - - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1, byte2, byte3;\n" \ - "mov.b32 {byte0, byte1, byte2, byte3}, %4;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ - "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ - "cvt.rn.f16x2.e2m1x2 %3, byte3;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) , "=r"(out_fp16[2]), "=r"(out_fp16[3]): "r"(src_packed)); - - float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); - float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); - float2 res2 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[2])); - float2 res3 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[3])); - - result_type out; - out[0] = res0.x; - out[1] = res0.y; - out[2] = res1.x; - out[3] = res1.y; - out[4] = res2.x; - out[5] = res2.y; - out[6] = res3.x; - out[7] = res3.y; - return out; - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 8; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float_e2m1_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint32_t tmp; - asm volatile( \ - "{\n" \ - ".reg .b8 byte0;\n" \ - ".reg .b8 byte1;\n" \ - ".reg .b8 byte2;\n" \ - ".reg .b8 byte3;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" \ - "}" \ - : "=r"(tmp) : "f"(source[0]), "f"(source[1])); - - uint8_t out = (tmp & 0xff); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = cutlass::float_e2m1_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - unsigned out; - asm volatile( \ - "{\n" \ - ".reg .b8 byte0;\n" \ - ".reg .b8 byte1;\n" \ - ".reg .b8 byte2;\n" \ - ".reg .b8 byte3;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" \ - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" \ - "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3]), - "f"(source[4]), "f"(source[5]), "f"(source[6]), "f"(source[7])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 8; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float_e2m1_t; - using source_element = float; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - uint16_t out; - asm volatile( \ - "{\n" \ - ".reg .b8 byte0;\n" \ - ".reg .b8 byte1;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ - "mov.b16 %0, {byte0, byte1};\n" \ - "}" \ - : "=h"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -/// Conversion is performed with saturation regardless of setting of -/// the `Round` template parameter. -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericConverter destination_converter; - result_type result; - result[0] = destination_converter(source[0]); - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericConverter destination_converter; - result_type result; - result[0] = destination_converter(source[0]); - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -// To convert a FP32 to Int that has less than 32 bits, we need to convert it to int32 first. -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayFP32ToIntConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - static_assert(cutlass::platform::numeric_limits::is_integer, "the dest type has to be int."); - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - // Convert float to int - Array temporary; - - NumericArrayConverter compute_converter; - temporary = compute_converter(source); - - // Convert to int to int8_t - NumericArrayConverter destination_converter; - return destination_converter(temporary); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - - -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ - ((__CUDACC_VER_MAJOR__ > 10) || \ - ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - unsigned out; - - asm volatile( - "{ .reg .u32 r4;" - "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" - "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" - "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" - "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" - "}" - : "=r"(out) - : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), - "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); - - return reinterpret_cast(out); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - unsigned out; - - asm volatile( - "{ .reg .u32 r4;" - "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" - "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" - "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" - "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" - "}" - : "=r"(out) - : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), - "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); - - return reinterpret_cast(out); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#endif // Conditional guards to enable partial specialization for packed integers - -namespace detail { - - /* - A helper class that can vectorize a numeric converter with implementation for several vector widths. - - The vector widths must be giving in decreasing order or width, and must be a power of 2. - - The vector converters must produce identical results to the scalar converters for consistency. - */ - class VectorizedConverter { - private: - // Base case to handle remainder elements as scalars. - template - CUTLASS_DEVICE - static void convert_helper( - typename ArrayConverter::result_type& result, - typename ArrayConverter::source_type const& source) { - - using ElementRes = typename ArrayConverter::result_type::Element; - using ElementSrc = typename ArrayConverter::source_type::Element; - // If no more converters, handle the remaining elements as scalars. - constexpr int total_elements = ArrayConverter::result_type::kElements; - constexpr int remainder = total_elements - Offset; - static_assert(remainder == (total_elements % ParentWidth), "Unexpected remainder."); - - typename ArrayConverter::ScalarConverter scalar_converter; - CUTLASS_PRAGMA_UNROLL - for (int i = Offset; i < ArrayConverter::result_type::kElements; ++i) { - result[i] = scalar_converter(ElementSrc(source[i])); - } - } - - template - CUTLASS_DEVICE - static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { - static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); - static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); - static_assert(cutlass::platform::is_same::value, - "ResultVectorArray must have the same type ArrayConverter::result_type"); - static_assert(cutlass::platform::is_same::value, - "SourceVectorArray must have the same type ArrayConverter::result_type"); - static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); - - static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); - - constexpr int vector_width = ResultVectorArray::kElements; - static_assert(ispow2(vector_width), "Vector width must be a power of 2"); - - using ElementRes = typename ArrayConverter::result_type::Element; - using ElementSrc = typename ArrayConverter::source_type::Element; - - constexpr int vector_bits_res = vector_width * cutlass::sizeof_bits::value; - constexpr int vector_bits_src = vector_width * cutlass::sizeof_bits::value; - - static_assert(vector_bits_res % 8 == 0, "Result vector type must be byte addressed."); - static_assert(vector_bits_src % 8 == 0, "Source vector type must be byte addressed."); - - constexpr int vector_offset = Offset / vector_width; - ResultVectorArray* packed_result_vec = reinterpret_cast(&result) + vector_offset; - SourceVectorArray const* packed_source_vec = reinterpret_cast(&source) + vector_offset; - - // Convert the remaining elements as vectors. - constexpr int total_elements = ArrayConverter::result_type::kElements; - constexpr int groups_of_vec = (total_elements - Offset) / vector_width; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < groups_of_vec; ++i) { - packed_result_vec[i] = ArrayConverter::template packed_convert(packed_source_vec[i]); - } - - constexpr int new_offset = Offset + vector_width * groups_of_vec; - // Recurse to handle other vector converters, or the scalar base case. - convert_helper(result, source); - } - - public: - /* - A method to convert vectors of elements using the packed_convert method of the converter. - - Converters using this class must implement packed convert and support 1 or more vector conversions. - */ - template - CUTLASS_DEVICE - static void convert(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { - convert_helper<0, 0, ArrayConverter, ResultVectorArray, SourceVectorArray, OtherVectorArrays...>(result, source); - } - }; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round, - int N -> -struct NumericArrayConverter { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e2m1_t; - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - CUTLASS_DEVICE - static result_type_packed_8 ptx_convert(source_type_packed_8 const &source) { - result_type_packed_8 out; - uint32_t* out_fp16 = reinterpret_cast(&out); - uint32_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1, byte2, byte3;\n" \ - "mov.b32 {byte0, byte1, byte2, byte3}, %4;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ - "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ - "cvt.rn.f16x2.e2m1x2 %3, byte3;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) , "=r"(out_fp16[2]), "=r"(out_fp16[3]): "r"(src_packed)); - return out; - } - - CUTLASS_DEVICE - static result_type_packed_4 ptx_convert(source_type_packed_4 const &source) { - result_type_packed_4 out; - uint32_t* out_fp16 = reinterpret_cast(&out); - uint16_t const& src_packed = reinterpret_cast(source); - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1;\n" \ - "mov.b16 {byte0, byte1}, %2;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ - "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "h"(src_packed)); - return out; - } - - CUTLASS_DEVICE - static result_type_packed_2 ptx_convert(source_type_packed_2 const &source) { - result_type_packed_2 out; - uint32_t* out_fp16 = reinterpret_cast(&out); - uint16_t const& src_packed = static_cast(reinterpret_cast(source)); - asm volatile( \ - "{\n" \ - ".reg .b8 byte0, byte1;\n" \ - "mov.b16 {byte0, byte1}, %1;\n" \ - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ - "}\n" : "=r"(out_fp16[0]) : "h"(src_packed)); - return out; - } - #endif - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - - #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - return ptx_convert(source); - #else - PackedResultType result; - NumericConverter converter; - - const int k_packed = PackedResultType::kElements; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < k_packed; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); - - // Hold output FP8s in reg. We need 1 reg for every 4 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 2; - - src_reg &= 0x333333333333; // s14s12s10s8s6s4s2s0 - src_reg_shifted &= 0x333333333333; // s15s13s11s9s7s5s3s1 - - // [0, 1, -2, -1] encoded as FP8 - static constexpr uint32_t E4M3_LUT = 0xB8C03800; - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { - // This uses a look up table to convert packed int2s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 f8_6420, f8_7531;\n" - " prmt.b32 f8_6420, %4, 0, %2;\n" - " prmt.b32 f8_7531, %4, 0, %3;\n" - " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 - " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 - "}\n" - : "=r"(r[ii]), "=r"(r[ii+1]) - : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); - - // Hold output FP8s in reg. We need 1 reg for every 4 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 2; - - src_reg &= 0x333333333333; // u14u12u10u8u6u4u2u0 - src_reg_shifted &= 0x333333333333; // u15u13u11u9u7u5u3u1 - - // [0, 1, 2, 3] encoded as FP8 - static constexpr uint32_t E4M3_LUT = 0x44403800; - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { - // This uses a look up table to convert packed uint2s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 f8_6420, f8_7531;\n" - " prmt.b32 f8_6420, %4, 0, %2;\n" - " prmt.b32 f8_7531, %4, 0, %3;\n" - " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 - " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 - "}\n" - : "=r"(r[ii]), "=r"(r[ii+1]) - : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); - - // Hold output FP8s in reg. We need 1 reg for every 4 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 2; - - src_reg &= 0x333333333333; // s14s12s10s8s6s4s2s0 - src_reg_shifted &= 0x333333333333; // s15s13s11s9s7s5s3s1 - - // [0, 1, -2, -1] encoded as FP8 - static constexpr uint32_t E4M3_LUT = 0xBCC03C00; - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { - // This uses a look up table to convert packed int2s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 f8_6420, f8_7531;\n" - " prmt.b32 f8_6420, %4, 0, %2;\n" - " prmt.b32 f8_7531, %4, 0, %3;\n" - " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 - " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 - "}\n" - : "=r"(r[ii]), "=r"(r[ii+1]) - : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); - - // Hold output FP8s in reg. We need 1 reg for every 4 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 2; - - src_reg &= 0x333333333333; // u14u12u10u8u6u4u2u0 - src_reg_shifted &= 0x333333333333; // u15u13u11u9u7u5u3u1 - - // [0, 1, 2, 3] encoded as FP8 - static constexpr uint32_t E4M3_LUT = 0x42403C00; - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { - // This uses a look up table to convert packed uint2s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 f8_6420, f8_7531;\n" - " prmt.b32 f8_6420, %4, 0, %2;\n" - " prmt.b32 f8_7531, %4, 0, %3;\n" - " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 - " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 - "}\n" - : "=r"(r[ii]), "=r"(r[ii+1]) - : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - - static_assert(N % 8 == 0, "N must be a multiple of 8"); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { - - #if defined(__CUDA_ARCH__) - - if constexpr ( N == 8 ) { - - unsigned const& storage = reinterpret_cast(source); - unsigned out[2]; - - asm volatile( - "{\n" - " .reg .u32 tmp0, tmp1, tmp2;\n" - " shl.b32 tmp0, %2, 4;\n" // tmp0 = x1x2x3x4x5x6x7__ - " and.b32 tmp0, tmp0, 0xf0f0f0f0;\n" // tmp0 = x1__x3__x5__x7__ - " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s1s3s5s7 - " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s1__s3__s5__s7__ - " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x1__x3__x5__x7 - " or.b32 tmp2, tmp0, tmp1;\n" // tmp2 = y1y3y5y7 - " and.b32 tmp0, %2, 0xf0f0f0f0;\n" // tmp0 = x0__x2__x4__x6__ - " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s0s2s4s6 - " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s0__s2__s4__s6__ - " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x0__x2__x4__x6 - " or.b32 tmp0, tmp0, tmp1;\n" // tmp0 = y0y2y4y6 - " prmt.b32 %0, tmp2, tmp0, 0x5140;\n" // %0 = y0y1y2y3 - " prmt.b32 %1, tmp2, tmp0, 0x7362;\n" // %1 = y4y5y6y7 - "}\n" - : "=r"(out[0]), "=r"(out[1]) - : "r"(storage)); - - return reinterpret_cast(out); - - } else { - - NumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - #else - - result_type result; - NumericConverter convert_; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = convert_(source[i]); - } - - return result; - - #endif // __CUDA_ARCH__ - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses a lookup table to converts i4 -> e4m3. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); - - // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. - cutlass::AlignedArray r; - - // View the input as reg - uint32_t reg = to_reg(source); - - // Determines if to get from the signed or unsigned candidates - uint32_t sign = (reg & 0x88888888) >> 1; - - // Ignore sign bit when indexing into LUT - uint32_t lut_idx = (reg & 0x77777777); - - // Signed is OR'd with 0x32103210 to find the correct value in the LUT - const uint32_t final_prmt_base = 0x32103210; - - // [0, 1, 2, 3] encoded as FP8 - static constexpr uint32_t POS_E4M3s_REG1 = 0x44403800; - // [4, 5, 6, 7] encoded as FP8 - static constexpr uint32_t POS_E4M3s_REG2 = 0x4E4C4A48; - // [-8, -7, -6, -5] encoded as FP8 - static constexpr uint32_t NEG_E4M3s_REG1 = 0xCACCCED0; - // [-4, -3, -2, -1] encoded as FP8 - static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8; - - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { - uint32_t final_prmt_idx = final_prmt_base | sign; - - // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 pos_f8s, neg_f8s;\n" - " prmt.b32 pos_f8s, %1, %2, %5;\n" - " prmt.b32 neg_f8s, %3, %4, %5;\n" - " prmt.b32 %0, pos_f8s, neg_f8s, %6;\n" - "}\n" - : "=r"(r[ii]) - : "n"(POS_E4M3s_REG1), "n"(POS_E4M3s_REG2), "n"(NEG_E4M3s_REG1), "n"(NEG_E4M3s_REG2), - "r"(lut_idx), "r"(final_prmt_idx)); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses a lookup table to converts i4 -> e5m2. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); - - // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. - cutlass::AlignedArray r; - - // View the input as reg - uint32_t reg = to_reg(source); - - // Determines if to get from the signed or unsigned candidates - uint32_t sign = (reg & 0x88888888) >> 1; - - // Ignore sign bit when indexing into LUT - uint32_t lut_idx = (reg & 0x77777777); - - // Signed is OR'd with 0x32103210 to find the correct value in the LUT - const uint32_t final_prmt_base = 0x32103210; - - // [0, 1, 2, 3] encoded as FP8 - static constexpr uint32_t POS_E5M2s_REG1 = 0x42403C00; - // [4, 5, 6, 7] encoded as FP8 - static constexpr uint32_t POS_E5M2s_REG2 = 0x47464544; - // [-8, -7, -6, -5] encoded as FP8 - static constexpr uint32_t NEG_E5M2s_REG1 = 0xC5C6C7C8; - // [-4, -3, -2, -1] encoded as FP8 - static constexpr uint32_t NEG_E5M2s_REG2 = 0xBCC0C2C4; - - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { - uint32_t final_prmt_idx = final_prmt_base | sign; - - // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 pos_f8s, neg_f8s;\n" - " prmt.b32 pos_f8s, %1, %2, %5;\n" - " prmt.b32 neg_f8s, %3, %4, %5;\n" - " prmt.b32 %0, pos_f8s, neg_f8s, %6;\n" - "}\n" - : "=r"(r[ii]) - : "n"(POS_E5M2s_REG1), "n"(POS_E5M2s_REG2), "n"(NEG_E5M2s_REG1), "n"(NEG_E5M2s_REG2), - "r"(lut_idx), "r"(final_prmt_idx)); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses a lookup table to converts u4 -> e4m3. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); - - // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. - cutlass::AlignedArray r; - - // View the input as reg - uint32_t reg = to_reg(source); - - // Determines if to get from the [0-7] or [8-15] candidates - uint32_t sign = (reg & 0x88888888) >> 1; - - // Ignore sign bit when indexing into LUT - uint32_t lut_idx = (reg & 0x77777777); - - // Signed is OR'd with 0x32103210 to find the correct value in the LUT - const uint32_t final_prmt_base = 0x32103210; - - // [0, 1, 2, 3] encoded as FP8 - static constexpr uint32_t E4M3s_REG1 = 0x44403800; - // [4, 5, 6, 7] encoded as FP8 - static constexpr uint32_t E4M3s_REG2 = 0x4E4C4A48; - // [8, 9, 10, 11] encoded as FP8 - static constexpr uint32_t E4M3s_REG3 = 0x53525150; - // [12, 13, 14, 15] encoded as FP8 - static constexpr uint32_t E4M3s_REG4 = 0x57565554; - - - const int iters = PackedSrcType::kElements / 4; - #pragma unroll - for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { - uint32_t final_prmt_idx = final_prmt_base | sign; - - // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value - // as the index to prmt. - // It first select both the positive and negative candidates, then uses the sign bit to - // select the correct candidate. - asm volatile( - "{\n" - " .reg .b32 f8s_1, f8s_2;\n" - " prmt.b32 f8s_1, %1, %2, %5;\n" - " prmt.b32 f8s_2, %3, %4, %5;\n" - " prmt.b32 %0, f8s_1, f8s_2, %6;\n" - "}\n" - : "=r"(r[ii]) - : "n"(E4M3s_REG1), "n"(E4M3s_REG2), "n"(E4M3s_REG3), "n"(E4M3s_REG4), - "r"(lut_idx), "r"(final_prmt_idx)); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static void packed_convert_vec(PackedResultType& result, uint32_t src_reg) { - static_assert(offset == 0 || offset == 4, "Invalid offset"); - // Selects one of the bottom int4s and constructs: - // 8388608 + (x + 8) - // 8388608 + 16 * (x + 8) - // 8388608 + 256 * (x + 8) - // 8388608 + 4096 * (x + 8) - uint32_t const and_masks[4] = {0x0000000F, 0x000000F0, 0x00000F00, 0x0000F000}; - uint32_t const xor_masks[4] = {0x4B000008, 0x4B000080, 0x4B000800, 0x4B008000}; - - float const scales[4] = {1.f, 1.f / 16.f, 1.f / 256.f, 1.f / 4096.f}; - float const offsets[4] = {-8388616.f, -524296.f, -32776.f, -2056.f}; - - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - uint32_t* result_as_int = reinterpret_cast(&result); - - // For each operand, computes: - // r[i] = (r[i] & and_mask) ^ xor_mask - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < elements_to_convert; ++ii) { - asm volatile( - "{\n" - " lop3.b32 %0, %1, %2, %3, %4;\n" - "}\n" - : "=r"(result_as_int[offset + ii]) - : "r"(src_reg), "r"(and_masks[ii]), "r"(xor_masks[ii]), "n"(immLut)); - - result[offset + ii] = __fmaf_rn(result[offset + ii], scales[ii], offsets[ii]); - } - } - - // The core converter uses bit tricks to construct a known FP16 number, then does a - // subtraction in FP16 for the final result. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 1, 2, 4 or 8 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - PackedResultType r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - constexpr int total_elements = PackedResultType::kElements == 8 ? 4 : PackedResultType::kElements; - packed_convert_vec<0, total_elements>(r, src_reg); - - - if (PackedResultType::kElements == 8) { - uint32_t src_reg_shifted = src_reg >> 16; - packed_convert_vec<4, 4>(r, src_reg_shifted); - } - return r; - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - CUTLASS_DEVICE - static int32_t to_int32(source_type_packed_2 const& source) { - return static_cast(reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static int32_t to_int32(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - - PackedResultType r; - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ <= 800 - // View the input as reg - uint32_t src_reg = to_reg(source); - static constexpr int fp32_base = 0x4B400000; - uint32_t const prmt_indices[4] = {0x8880, 0x9991, 0xAAA2, 0xBBB3}; - - int* result_as_int = reinterpret_cast(&r); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < PackedResultType::kElements; ++ii) { - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_as_int[ii]) : "r"(src_reg), "r"(prmt_indices[ii])); - } - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < PackedResultType::kElements; ++ii) - { - result_as_int[ii] += fp32_base; - r[ii] -= reinterpret_cast(fp32_base); - } - #else - int32_t x = to_int32(source); - int32_t t[4]; - constexpr int32_t mask[4] = {0x00000001, 0x00000100, 0x00010000, 0x01000000}; - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < PackedResultType::kElements; ++ii) { - t[ii] = __dp4a(x, mask[ii], 0); - r[ii] = static_cast(t[ii]); - } - #endif - - return r; - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - - PackedResultType r; - // View the input as reg - uint32_t src_reg = to_reg(source); - - // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores - // the result in r (without introducing extra cvt.u32.u8 instruction) - uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; - uint32_t* result_as_int = reinterpret_cast(&r); - for (int ii = 0; ii < PackedResultType::kElements; ++ii) { - result_as_int[ii] = __byte_perm(src_reg, 0x4B000000, prmt_indices[ii]); - // Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain final result - r[ii] -= 8388608.f; - } - - return r; - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 4; - - // Below constructs the following temporary: - // f1f0 = {0x00, i3i2i1i0, 0x00, i3i2i1i0} - // f3f2 = {0x00, i5i4i3i2, 0x00, i5i4i3i2} - // f5f4 = {0x00, i7i6i5i4, 0x00, i7i6i5i4} - // f7f6 = {0x00, i9i8i7i6, 0x00, i9i8i7i6} - // f9f8 = {0x00, i11i10i9i8, 0x00, i11i10i9i8} - // f11f10 = {0x00, i13i12i11i10, 0x00, i13i12i11i10} - // f13f12 = {0x00, i15i14i13i12, 0x00, i15i14i13i12} - // f15f14 = {0x00, 0000i15i14, 0x00, 0000i15i14} - // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC - // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. - uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; - static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter"); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ii += 2) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2])); - - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii + 1]) - : "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2])); - } - - // The below XOR does the following: - // Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing - // 1024 + x + 2, 1024 + 4 * (x + 2) - // We use lop3 so that we can use 1 instruction for AND and XOR. - // static constexpr uint32_t xor_mask[2] = { 0x64086402, 0x64806420}; - // static constexpr uint32_t and_mask[2] = { 0x000C0003, 0x00C00030}; - static constexpr uint32_t xor_mask = 0x64086402; - static constexpr uint32_t and_mask = 0x000C0003; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - // For each operand, computes: - // r[i] = (r[i] & and_mask[i / 2]) ^ xor_mask[i / 2] - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ lop3.b32 %0, %0, %1, %2, %3; }\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - // {-258, -1026} - static constexpr uint32_t hfma_bias_rep = 0xDC08E402; - // {1/4, 1} - static constexpr uint32_t hfma_scale_rep = 0x34003C00; - - // Scale and subtract the FP16s to get the original int4 number as FP16. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hfma2(fp16x2_val, - reinterpret_cast(hfma_scale_rep), - reinterpret_cast(hfma_bias_rep)); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 4; - - // Below constructs the following temporary: - // f1f0 = {0x00, u3u2u1u0, 0x00, u3u2u1u0} - // f3f2 = {0x00, u5u4u3u2, 0x00, u5u4u3u2} - // f5f4 = {0x00, u7u6u5u4, 0x00, u7u6u5u4} - // f7f6 = {0x00, u9u8u7u6, 0x00, u9u8u7u6} - // f9f8 = {0x00, u11u10u9u8, 0x00, u11u10u9u8} - // f11f10 = {0x00, u13u12u11u10, 0x00, u13u12u11u10} - // f13f12 = {0x00, u15u14u13u12, 0x00, u15u14u13u12} - // f15f14 = {0x00, 0000u15u14, 0x00, 0000u15u14} - // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC - // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. - uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; - static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter"); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ii += 2) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2])); - - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii + 1]) - : "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2])); - } - - // The below XOR does the following: - // Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing - // 1024 + x, 1024 + 4 * x - // We use lop3 so that we can use 1 instruction for AND and OR. - static constexpr uint32_t xor_mask = 0x64006400; - static constexpr uint32_t and_mask = 0x000C0003; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - // For each operand, computes: - // r[i] = (r[i] & and_mask[i / 2]) ^ xor_mask[i / 2] - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ lop3.b32 %0, %0, %1, %2, %3; }\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - // {-256, -1024} - static constexpr uint32_t hfma_bias_rep = 0xDC00E400; - // {1/4, 1} - static constexpr uint32_t hfma_scale_rep = 0x34003C00; - - // Scale and subtract the FP16s to get the original int4 number as FP16. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hfma2(fp16x2_val, - reinterpret_cast(hfma_scale_rep), - reinterpret_cast(hfma_bias_rep)); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses bit tricks to construct a known FP16 number, then does a - // subtraction in FP16 for the final result. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - - // Below constructs the following temporary: - // fp16s_01 = {0x00, i4_01, 0x00, i4_01} - // fp16s_23 = {0x00, i4_23, 0x00, i4_23} - // fp16s_45 = {0x00, i4_45, 0x00, i4_45} - // fp16s_67 = {0x00, i4_67, 0x00, i4_67} - // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC - // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. - uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; - static_assert(RegArray::kElements <= 4, "Too many inputs for I4 ->F16 vector converter"); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); - } - - // The below XOR does the following: - // 1) Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing - // 1024 + x + 8 OR 1024 + 16 * (x + 8), then using hfma to subtract 1032 from that - // 2) Adds 8 to the int4 value that we will process in the FP16 (for uint4, we can simply avoid this step) - // The AND does the following: - // 1) Clear the set bits for the int4 we will ignore. - // We use lop3 so that we can use 1 instruction for AND and XOR. - static constexpr uint32_t xor_mask = 0x64806408; - static constexpr uint32_t and_mask = 0xFFF0FF0F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - // For each operand, computes: - // r[i] = (r[i] & and_mask) ^ xor_mask - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - // We will issue 2 hfmas that do the following: - // For the high FP16: - // Divide by 16 {packed as a operand} to get: - // 64 + (x + 8) - // x + 72 - // Subtract 72 {packed as c operand} to get x - // For the low FP16: - // 1024 + (x + 8) - // x + 1032 - // So, we subtract 1032 {packed as c operand} to get x - - // {-72, -1032} - static constexpr uint32_t hfma_bias_rep = 0xD480E408; - // {1 / 16, 1} - static constexpr uint32_t hfma_scale_rep = 0x2C003C00; - - const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); - const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); - // Scale and subtract the FP16s to get the original int4 number as FP16. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses bit tricks to construct a known FP16 number, then does a - // subtraction in FP16 for the final result. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - // Below constructs the following temporary: - // fp16s_01 = {0x00, u4_01, 0x00, u4_01} - // fp16s_23 = {0x00, u4_23, 0x00, u4_23} - // fp16s_45 = {0x00, u4_45, 0x00, u4_45} - // fp16s_67 = {0x00, u4_67, 0x00, u4_67} - uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; - static_assert(RegArray::kElements <= 4, "Too many inputs for u4 -> f16 vector converter"); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); - } - - // The below XOR does the following: - // Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing - // 1024 + x, then using hsub2 to subtract 1024 from that - static constexpr uint32_t or_mask = 0x64006400; - static constexpr uint32_t and_mask = 0x00F0000F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - - // For each operand, computes: - // r[i] = (r[i] & and_mask) | or_mask - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(or_mask), "n"(immLut)); - - // We will issue 2 hfmas that do the following: - // For the high FP16: - // Divide by 16 {packed as a operand} to get: - // 64 + x - // Subtract 64 {packed as c operand} to get x - // For the low FP16: - // we subtract 1024 {packed as c operand} to get x - - static constexpr uint32_t hfma_bias = 0xD400E400; // {-64, -1024} - static constexpr uint32_t hfma_scale = 0x2C003C00; // {1 / 16, 1} - - { - __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hfma2(fp16x2_val, reinterpret_cast(hfma_scale), reinterpret_cast(hfma_bias)); - } - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses bit tricks to construct a known FP16 number, then does a - // subtraction in FP16 for the final result. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) - auto result = reinterpret_cast(r); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < PackedResultType::kElements; ++i) { - int16_t tmp = source[i] + 26112 /* 0x6600 */; - result[i] = reinterpret_cast(tmp) - 1536.0_hf; - } - #endif - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t const prmt_indices[2] = {0x9180, 0xB3A2}; - - // Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0]) - // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) - // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign-extend the sign bit in 0, 1, 2, 3 bytes of s8x4 - // into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively. - // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same result and doesn't sign-extend the sign bit. - // Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from s8x2 to s16x2. - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(r[ii]) : "r"(src_reg), "r"(prmt_indices[ii])); - } - - // In the absence of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve - // the same result as add.s16x2 instruction. - // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) - // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to - // three predefined constant values as follows: - // ta = 0xF0; - // tb = 0xCC; - // tc = 0xAA; - // kImmLut = F(ta, tb, tc); - // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA - static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; - - for (int ii = 0; ii < RegArray::kElements; ++ii) { - // The bit-wise operation executed below is `r[ii] = (r[ii] & 0x03FF03FF) ^ 0x66006600;` - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : - "=r"(r[ii]) : "r"(r[ii]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); - } - - static constexpr uint32_t bias_rep = 0x66006600; - const half2& bias = reinterpret_cast(bias_rep); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hsub2(fp16x2_val, bias); - } - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t const prmt_indices[2] = {0x5150, 0x5352}; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(r[ii]) : "r"(src_reg), "n"(start_byte_for_fp16), "r"(prmt_indices[ii])); - } - - static constexpr uint32_t bias_rep = 0x64006400; - const half2& bias = reinterpret_cast(bias_rep); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hsub2(fp16x2_val, bias); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); - - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted_two = src_reg >> 2; - uint32_t src_reg_shifted_four = src_reg >> 4; - uint32_t src_reg_shifted_six = src_reg >> 6; - - // Modified prmt indices for signed 2-bit values - uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; - - static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> BF16 vector converter"); - - // First pass: extract and sign extend the 2-bit values - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ii += 2) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2])); - - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii + 1]) - : "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2])); - } - - // For signed 2-bit integers: - // 00 -> 0 (0) - // 01 -> 1 (1) - // 10 -> -2 (2 with sign extension) - // 11 -> -1 (3 with sign extension) - //static constexpr uint32_t sign_mask = 0x00020002; // Mask to check sign bit - static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits - - // Modified for signed range (-2 to 1) - // We'll construct numbers in the form 128 + (x + 2) and then subtract 130 - // to get back to our original range - static constexpr uint32_t xor_mask = 0x43024302; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - // Bias represents 130 in bfloat16 format - // Subtracting 130 brings us back to our signed range (-2 to 1) - static constexpr uint32_t bias_rep = 0x43024302; // {130, 130} in bfloat16 - const __nv_bfloat162& bias = reinterpret_cast(bias_rep); - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, bias); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_16 = Array; - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using source_type_packed_16 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_16 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); - - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted_two = src_reg >> 2; - uint32_t src_reg_shifted_four = src_reg >> 4; - uint32_t src_reg_shifted_six = src_reg >> 6; - - // Modified prmt indices for signed 2-bit values - uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; - - static_assert(RegArray::kElements <= 8, "Too many inputs for U2 -> BF16 vector converter"); - - // First pass: extract and sign extend the 2-bit values - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ii += 2) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2])); - - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii + 1]) - : "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2])); - } - - static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits - static constexpr uint32_t xor_mask = 0x43004300; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ lop3.b32 %0, %0, %1, %2, %3; }" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - static constexpr uint32_t bias_rep = xor_mask; // {128, 128} in bfloat16 - const __nv_bfloat162& bias = reinterpret_cast(bias_rep); - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, bias); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses bit tricks to construct a known FP16 number, then does a - // subtraction in FP16 for the final result. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 4; - - // Below constructs the following temporary: - uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; - static_assert(RegArray::kElements <= 4, "Too many inputs for BF16 -> I4 vector converter"); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ prmt.b32 %0, %1, %2, %3; }\n" - : "=r"(r[ii]) - : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); - } - - // The below XOR does the following: - // 1) Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing - // 128 + (x + 8) and subtracting 136 to get x - static constexpr uint32_t xor_mask = 0x43084308; - static constexpr uint32_t and_mask = 0x000F000F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - // For each operand, computes: - // r[i] = (r[i] & and_mask) ^ xor_mask - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{ lop3.b32 %0, %0, %1, %2, %3; }\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - // We will issue 2 bfmas that do the following: - // high BF16: - // hi_bf16 - 136, lo_bf16 - 136 - - // This is the BF16 {136, 136} represented as an integer. - static constexpr uint32_t bias_rep = 0x43084308; - const __nv_bfloat162& bias = reinterpret_cast(bias_rep); - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, bias); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_8 = Array; - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_8 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_8 const& source) { - return reinterpret_cast(source); - } - - // The core converter uses bit tricks to construct a known FP16 number, then does a - // subtraction in FP16 for the final result. - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - - // Hold output FP16s in reg. We need 1 reg for every 2 elements - using RegArray = cutlass::AlignedArray; - RegArray r; - - // View the input as reg - uint32_t src_reg = to_reg(source); - uint32_t src_reg_shifted = src_reg >> 4; - - // Below constructs the following temporary: - // fp16s_01 = {0x00, u4_21, 0x00, u4_10} - // fp16s_23 = {0x00, u4_43, 0x00, u4_32} - // fp16s_45 = {0x00, u4_65, 0x00, u4_54} - // fp16s_67 = {0x000, u4_7, 0x00, u4_76} - static constexpr uint32_t prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; - static_assert(RegArray::kElements <= 4, "Too many inputs for BF16 -> I4 vector converter"); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{\n" - " prmt.b32 %0, %1, %2, %3;\n" - "}\n" - : "=r"(r[ii]) - : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); - } - - static constexpr uint32_t xor_mask = 0x43004300; - static constexpr uint32_t and_mask = 0x000F000F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; - - // For each operand, computes: - // r[i] = (r[i] & and_mask) ^ xor_mask - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - - // We will issue 2 bfmas that do the following: - // high BF16: - // hi_bf16 - 128, lo_bf16 - 128 - - // This is the BF16 {128, 128} represented as an integer. - static constexpr uint32_t bias = xor_mask; - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < RegArray::kElements; ++ii) { - __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, reinterpret_cast(bias)); - } - - return reinterpret_cast(r); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - - NumericArrayConverter convert_int8_to_f32; - Array tmp = convert_int8_to_f32(source); - NumericArrayConverter convert_f32_to_bf16; - return convert_f32_to_bf16(tmp); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct NumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - -private: - using result_type_packed_4 = Array; - using result_type_packed_2 = Array; - using source_type_packed_4 = Array; - using source_type_packed_2 = Array; - - using ScalarConverter = NumericConverter; - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_2 const& source) { - return static_cast( - reinterpret_cast(source)); - } - - CUTLASS_DEVICE - static uint32_t to_reg(source_type_packed_4 const& source) { - return reinterpret_cast(source); - } - - template - CUTLASS_DEVICE - static PackedResultType packed_convert(PackedSrcType const &source) { - - static_assert((platform::is_same::value && - platform::is_same::value) || - (platform::is_same::value && - platform::is_same::value), - "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - - NumericArrayConverter convert_uint8_to_f32; - Array tmp = convert_uint8_to_f32(source); - NumericArrayConverter convert_f32_to_bf16_; - return convert_f32_to_bf16_(tmp); - } - - friend class detail::VectorizedConverter; - -public: - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// FastNumericArrayConverter only works when the source is within center range. -/// Conversion operator for Array. See the comments before -/// FastLinearCombinationClamp. -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &s) { - NumericArrayConverter convert_; - - return convert_(s); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { return convert(s); } -}; - -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int tmp = source[i] + 1262485504 /*0x4B400000*/; - result[i] = reinterpret_cast(tmp) - 12582912.0f; - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { return convert(s); } -}; - -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - Array result; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - float tmp = source[i] + 12582912.0f; - result[i] = reinterpret_cast(tmp); - } - - result[0] = __byte_perm(result[0], result[1], 0x40); - result[2] = __byte_perm(result[2], result[3], 0x40); - result[0] = __byte_perm(result[0], result[2], 0x5410); - - return reinterpret_cast(result[0]); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { return convert(s); } -}; - -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - static_assert(!(N % 4), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - FastNumericArrayConverter convert_vector_; - - result_type result; - - Array *result_ptr = - reinterpret_cast *>(&result); - Array const *source_ptr = - reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { return convert(s); } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines preferred rounding mode for a pair of types -template -struct PreferredRoundingMode { - static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; -}; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 900 -/// Defines preferred rounding mode for a pair of types -template <> -struct PreferredRoundingMode { - static FloatRoundStyle const kRound = FloatRoundStyle::round_half_ulp_truncate; -}; -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Packs predicates into an array. -template -struct PackPredicates { - using result_type = Array; - - static_assert(!(N % 4), "Must pack predicates in a count that is a multiple of 4"); - - CUTLASS_HOST_DEVICE - result_type operator()(bool const predicates[]) { - - result_type packed; - packed.clear(); - - int const kWordSize = 8; - uint8_t *bytes = reinterpret_cast(packed.data()); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int word_idx = (i / kWordSize); - int bit_idx = (i % kWordSize); - - uint8_t mask = static_cast((predicates[i] ? 1u : 0u) << bit_idx); - bytes[word_idx] = (bytes[word_idx] | mask); - } - return packed; - } -}; - -/// Packs predicates into an array -template -struct UnpackPredicates { - using result_type = Array; - - static_assert(!(N % 4), "Must unpack predicates in a count that is a multiple of 4"); - - CUTLASS_HOST_DEVICE - void operator()(bool predicates[], result_type const &packed) { - - int const kWordSize = 8; - uint8_t const *bytes = reinterpret_cast(packed.data()); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int word_idx = (i / kWordSize); - int bit_idx = (i % kWordSize); - - predicates[i] = bool((bytes[word_idx] >> bit_idx) & 0x1); - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_size.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_size.h deleted file mode 100644 index 0d8f2ada075c5bfc54ee3667b2153116647da7bf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_size.h +++ /dev/null @@ -1,98 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Top-level include for all CUTLASS numeric types. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines the size of an element in bits -template -struct sizeof_bits { - static constexpr int value = int(sizeof(T) * 8); -}; - -template -struct sizeof_bits : sizeof_bits {}; - -template -struct sizeof_bits : sizeof_bits {}; - -template -struct sizeof_bits : sizeof_bits {}; - -template <> -struct sizeof_bits { - static constexpr int value = 0; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns the number of bytes required to hold a specified number of bits -template -CUTLASS_HOST_DEVICE -constexpr -R -bits_to_bytes(T bits) { - return (R(bits) + R(7)) / R(8); -} - -/// Returns the number of bits required to hold a specified number of bytes -template -CUTLASS_HOST_DEVICE -constexpr -R -bytes_to_bits(T bytes) { - return R(bytes) * R(8); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct is_subbyte { - static constexpr bool value = sizeof_bits::value < 8; -}; - -template -struct is_subbyte : is_subbyte {}; - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_types.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_types.h deleted file mode 100644 index 0d814ed29150b2a13131a1f4a7d3cc13174336c9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/numeric_types.h +++ /dev/null @@ -1,114 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Top-level include for all CUTLASS numeric types. -*/ -#pragma once - -#include "cute/util/type_traits.hpp" - -#include "cutlass/numeric_size.h" -#include "cutlass/integer_subbyte.h" -#include "cutlass/half.h" -#include "cutlass/bfloat16.h" -#include "cutlass/tfloat32.h" -#include "cutlass/float8.h" -#include "cutlass/uint128.h" -#include "cutlass/uint256.h" -#include "cutlass/exmy_base.h" -#include "cutlass/float_subbyte.h" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct index_sequence; - -template -struct index_sequence_helper : index_sequence_helper {}; - -template -struct index_sequence_helper<0, 0, Next...> { - using type = index_sequence<0, Next...>; -}; - -template -using make_index_sequence = typename index_sequence_helper::type; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Default case - no negative zero -template -struct has_negative_zero : CUTE_STL_NAMESPACE::false_type{}; - -// Float types that support negative zero -template <> struct has_negative_zero> : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero> : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero> : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero> : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; -template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; - -// Helper variable template -template -inline constexpr bool has_negative_zero_v = has_negative_zero::value; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Get the register type used in kernel -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct get_unpacked_element_type { - using type = T; -}; - -} // namespace detail - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/pipeline.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/pipeline.hpp deleted file mode 100644 index e9cf66a794fef4631b715e8b6009c99425b3330f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/pipeline.hpp +++ /dev/null @@ -1,38 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/pipeline/sm90_pipeline.hpp" -#include "cutlass/pipeline/sm100_pipeline.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp deleted file mode 100644 index 4014bd006f6e08feff24a82eb8f12ac11462c9ad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/sm100_pipeline.hpp +++ /dev/null @@ -1,1328 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once -// - -// - -#include "cute/numeric/integral_constant.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/barrier.h" -#include "cutlass/pipeline/sm90_pipeline.hpp" -#include "sm90_pipeline.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -using namespace cute; - -enum class McastDirection { - kRow, - kCol, - kRowCol -}; -namespace detail { - -template -CUTLASS_DEVICE -uint16_t calculate_multicast_mask(ClusterShape cluster_shape, AtomThrShape_MNK atom_thr_shape, dim3 block_id_in_cluster) { - auto is_participant = [&](auto x, auto y) { - if constexpr (McastDir == McastDirection::kRowCol) { - return (x/size<0>(atom_thr_shape) == block_id_in_cluster.x/size<0>(atom_thr_shape) || // is same MMA cluster col - y/size<1>(atom_thr_shape) == block_id_in_cluster.y/size<1>(atom_thr_shape)); // is same MMA cluster row - } - else if constexpr (McastDir == McastDirection::kRow) { - return (x/size<0>(atom_thr_shape) == block_id_in_cluster.x/size<0>(atom_thr_shape)); // is same MMA cluster row - } - else { // (McastDir == McastDirection::kCol) - return (y/size<1>(atom_thr_shape) == block_id_in_cluster.y/size<1>(atom_thr_shape)); // is same MMA cluster col - } - }; - - uint16_t block_id_mask = 0; - auto cluster_layout = make_layout(cluster_shape); - // When MMA_2x1SM instructions are used, the definition of "same row" changes. - // With MMA_2x1SM, we need to send the notification for MMA completion to all - // 2x1 threadblocks of the cluster. Below is a 4x4 example where R are the threadblocks - // that receives the release for A/B buffers that threadblock (0,0) uses. - // Row&Col Row Col - // RRRR RRRR Cxxx - // RRRR RRRR Cxxx - // Rxxx xxxx Cxxx - // Rxxx xxxx Cxxx - CUTLASS_PRAGMA_UNROLL - for (int x = 0; x(cluster_shape); x++) { - CUTLASS_PRAGMA_UNROLL - for (int y = 0; y(cluster_shape); y++) { - if (is_participant(x,y)) { - block_id_mask |= (1 << cluster_layout(x,y, Int<0>{})); - } - } - } - return block_id_mask; -} - -template -CUTLASS_DEVICE -uint16_t calculate_umma_peer_mask(ClusterShape cluster_shape, AtomThrShape_MNK atom_thr_shape, dim3 block_id_in_cluster) { - uint16_t tmem_sync_mask = 0; - auto cluster_layout = make_layout(cluster_shape); - int block_id_in_cluster_x = (block_id_in_cluster.x / size<0>(AtomThrShape_MNK{})) * size<0>(AtomThrShape_MNK{}) ; - int block_id_in_cluster_y = (block_id_in_cluster.y / size<1>(AtomThrShape_MNK{})) * size<1>(AtomThrShape_MNK{}) ; - CUTLASS_PRAGMA_UNROLL - for (int x = 0; x < size<0>(AtomThrShape_MNK{}); x++) { - CUTLASS_PRAGMA_UNROLL - for (int y = 0; y < size<1>(AtomThrShape_MNK{}); y++) { - tmem_sync_mask |= (1 << cluster_layout(block_id_in_cluster_x + x, block_id_in_cluster_y + y, Int<0>{})); - } - } - - return tmem_sync_mask; -} -} // namespace detail - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA (producer) Async Pipeline class for Blackwell UMMA -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -template > -class PipelineUmmaAsync { -public: - static constexpr uint32_t Stages = Stages_; - using AtomThrShape_MNK = AtomThrShape_MNK_; -private: - using Impl = PipelineAsync; -public: - using FullBarrier = typename Impl::FullBarrier; - using EmptyBarrier = typename Impl::EmptyBarrier; - using ProducerBarrierType = typename Impl::ProducerBarrierType; - using ConsumerBarrierType = typename Impl::ConsumerBarrierType; - using PipelineState = typename Impl::PipelineState; - using SharedStorage = typename Impl::SharedStorage; - using ThreadCategory = typename Impl::ThreadCategory; - using Params = typename Impl::Params; - - // Helper function to initialize barriers - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params) { - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); - CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); - } - cutlass::arch::fence_barrier_init(); - } - - template - CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { - // Calculate producer mask - if (params_.role == ThreadCategory::Producer) { - // The leader threadblock executing the MMA_2x1SM instruction will signal its peer - // threadblock when it is done with MMA operations. tmem_sync_mask encodes the - // position of peer SMs in the cluster - tmem_sync_mask_ = detail::calculate_umma_peer_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); - } - } - - // Constructor by default initializes barriers and calculates masks. - // These operations can be explicity deferred by specifying InitBarriers and InitMasks. - // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. - template - CUTLASS_DEVICE - PipelineUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, InitBarriers{}) - , params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape); - } - } - - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return impl_.producer_try_acquire(state, skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - impl_.producer_acquire(state, barrier_token); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index()); - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - impl_.producer_tail(state); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return impl_.producer_get_barrier(state.index()); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return impl_.consumer_try_wait(state, skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - impl_.consumer_wait(state, barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - detail::pipeline_check_is_consumer(params_.role); - if constexpr (is_2sm_mma) { - consumer_release_2x1SM(state.index()); - } else { - impl_.consumer_release(state); - } - } - -private: - Impl impl_; - Params params_; - FullBarrier* full_barrier_ptr_ = nullptr; - EmptyBarrier* empty_barrier_ptr_ = nullptr; - uint16_t tmem_sync_mask_ = 0; - static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; - - CUTLASS_DEVICE - void producer_commit(uint32_t stage) { - detail::pipeline_check_is_producer(params_.role); - uint64_t* smem_ptr = reinterpret_cast(&full_barrier_ptr_[stage]); - if constexpr (is_2sm_mma) { - cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, tmem_sync_mask_); - } - else { - cutlass::arch::umma_arrive(smem_ptr); - } - } - - CUTLASS_DEVICE - void consumer_release_2x1SM(uint32_t stage) { - detail::pipeline_check_is_consumer(params_.role); - uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); - cutlass::arch::umma_arrive_2x1SM_sm0(smem_ptr); - static_assert(is_2sm_mma, "ERROR : AtomThrShape_MNK does not correspond to a 2SM MMMA"); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA (producer) Transform (consumer) Async Pipeline -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -template < - int Stages_, - class AtomThrShape_MNK_ = Shape<_1,_1,_1> -> -class PipelineTmaTransformAsync { -public: - static constexpr uint32_t Stages = Stages_; - using AtomThrShape_MNK = AtomThrShape_MNK_; -private: - using Impl = PipelineTmaAsync; -public: - using FullBarrier = typename Impl::FullBarrier; - using EmptyBarrier = typename Impl::EmptyBarrier; - using ProducerBarrierType = typename Impl::ProducerBarrierType; - using ConsumerBarrierType = typename Impl::ConsumerBarrierType; - using PipelineState = typename Impl::PipelineState; - using SharedStorage = typename Impl::SharedStorage; - using ThreadCategory = typename Impl::ThreadCategory; - using Params = typename Impl::Params; - - // Constructor - template - CUTLASS_DEVICE - PipelineTmaTransformAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) - , params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_, cluster_shape); - } - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape); - } - } - - template - CUTLASS_DEVICE - PipelineTmaTransformAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) - , params_(params) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) - , full_barrier_ptr_(&storage.full_barrier_[0]) { - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_, cluster_shape, mcast_direction); - } - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape, mcast_direction); - } - } - - // Helper function to initialize barriers - template - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - constexpr int producer_arv_cnt = 1; - auto atom_thr_shape = AtomThrShape_MNK{}; - static constexpr bool IsDynamicCluster = not cute::is_static_v; - static_assert(IsDynamicCluster or ((cute::size<0>(cluster_shape) % cute::size<0>(atom_thr_shape) == 0) && - (cute::size<1>(cluster_shape) % cute::size<1>(atom_thr_shape) == 0))); - uint32_t const num_consumer_per_cluster = cute::ceil_div(params.num_consumers, static_cast(NumThreadsPerWarpGroup)); - uint32_t const multicast_consumer_arrival_count = ((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + - (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) * num_consumer_per_cluster; - CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); - CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); - } - cutlass::arch::fence_barrier_init(); - } - - template - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { - auto atom_thr_shape = AtomThrShape_MNK{}; - - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; - uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? - (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) * num_consumer_per_cluster : // Mcast with row ctas - (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) * num_consumer_per_cluster; // Mcast with col ctas - - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); - - } - cutlass::arch::fence_barrier_init(); - } - - template - CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster(), McastDirection mcast_dir = McastDirection::kRowCol) { - // Calculate consumer mask - if (params_.role == ThreadCategory::Consumer) { - // Logic to optimally schedule Empty Arrives - // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) - int warp_idx = canonical_warp_idx_sync(); - int thread_idx = threadIdx.x; - auto cluster_size = cute::size(cluster_shape); - - // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) - if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { - auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warpgroup(thread_idx % NumThreadsPerWarpGroup, warp_idx); - is_signaling_thread_ = is_signaling_thread; - dst_blockid_ = dst_blockid; - } - else if (params_.num_consumers == 32) { - auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warp(thread_idx % 32); - is_signaling_thread_ = is_signaling_thread; - dst_blockid_ = dst_blockid; - } - else { - is_signaling_thread_ = 0; - #ifndef NDEBUG - asm volatile ("brkpt;\n" ::); - #endif - } - - // STEP 2: Find if this dst block-id needs an arrival for this problem - is_signaling_thread_ &= dst_blockid_ < cluster_size; - if(mcast_dir == McastDirection::kRowCol){ - is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id_in_cluster, cluster_shape); - } - if(mcast_dir == McastDirection::kRow){ - is_signaling_thread_ &= is_same_row(dst_blockid_, block_id_in_cluster, cluster_shape); - } - } - } - - template - CUTLASS_DEVICE - bool is_same_row(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { - return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) - // If we are in the same cluster column and using 2CTA MMA, only odd or only even CTAs sync with each other - && ((dst_block_id % cute::size<0>(cluster_shape)) % cute::size<0>(AtomThrShape_MNK{}) == - block_id.x % cute::size<0>(AtomThrShape_MNK{})) - ); - } - - template - CUTLASS_DEVICE - bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { - return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) || - ( - ((dst_block_id / cute::size<0>(cluster_shape)) == block_id.y) - // If we are in the same cluster column and using 2CTA MMA, only odd or only even CTAs sync with each other - && ((dst_block_id % cute::size<0>(cluster_shape)) % cute::size<0>(AtomThrShape_MNK{}) == - block_id.x % cute::size<0>(AtomThrShape_MNK{})) - )); - } - - //////////////////// - // Producer APIs - //////////////////// - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return impl_.producer_try_acquire(state, skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - impl_.producer_acquire(state, barrier_token); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state, uint32_t bytes) { - impl_.producer_commit(state, bytes); - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - impl_.producer_tail(state); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return impl_.producer_get_barrier(state); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return impl_.consumer_try_wait(state, skip_wait); - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(PipelineState state, uint32_t skip_wait = false) { - return impl_.consumer_test_wait(state, skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state) { - impl_.consumer_wait(state); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token) { - impl_.consumer_wait(state, barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state, uint32_t skip = false) { - detail::pipeline_check_is_consumer(params_.role); - empty_barrier_ptr_[state.index()].arrive(dst_blockid_, is_signaling_thread_ & (!skip)); - } - -private: - Impl impl_; - uint32_t dst_blockid_ = 0; - uint32_t is_signaling_thread_ = 0; - FullBarrier *full_barrier_ptr_ = nullptr; - EmptyBarrier *empty_barrier_ptr_ = nullptr; - Params params_; -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA (consumer) Async Pipeline classes for Blackwell UMMA -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Producer-consumer pipeline implementation -// for UMMA producer. In this case, UMMA barrier arrives are used -// by producer_commit. Use case, accumulator generation as -// the result of MMA instructions. -template < - int Stages_, - class ClusterShape = Shape, - class AtomThrShape_MNK_ = Shape<_1,_1,_1> -> -class PipelineTmaUmmaAsync { -public: - static constexpr uint32_t Stages = Stages_; - using AtomThrShape_MNK = AtomThrShape_MNK_; -private: - using Impl = PipelineTmaAsync; -public: - using FullBarrier = typename Impl::FullBarrier; - using EmptyBarrier = typename Impl::EmptyBarrier; - using ProducerBarrierType = typename Impl::ProducerBarrierType; - using ConsumerBarrierType = typename Impl::ConsumerBarrierType; - using PipelineState = typename Impl::PipelineState; - using SharedStorage = typename Impl::SharedStorage; - using ThreadCategory = typename Impl::ThreadCategory; - using Params = typename Impl::Params; - - using McastDirection = McastDirection; - - // Helper function to initialize barriers - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - constexpr int producer_arv_cnt = 1; - auto atom_thr_shape = AtomThrShape_MNK{}; - uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + - (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; - CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); - CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); - } - cutlass::arch::fence_barrier_init(); - } - - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { - auto atom_thr_shape = AtomThrShape_MNK{}; - - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - constexpr int producer_arv_cnt = 1; - uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? - cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas - cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas - - CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); - CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); - } - cutlass::arch::fence_barrier_init(); - } - - CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { - // Calculate consumer mask - if (params_.role == ThreadCategory::Consumer) { - auto cluster_layout = make_layout(cluster_shape); - block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); - } - } - - CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { - // Calculate consumer mask - dim3 block_id_in_cluster = cute::block_id_in_cluster(); - auto cluster_layout = make_layout(cluster_shape); - if (mcast_direction == McastDirection::kRow) { - block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); - } - else { - block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); - } - } - - // Constructor by default initializes barriers and calculates masks. - // These operations can be explicity deferred by specifying InitBarriers and InitMasks. - // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. - template - CUTLASS_DEVICE - PipelineTmaUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) - , params_(params) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) - , full_barrier_ptr_(&storage.full_barrier_[0]) { - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_, cluster_shape); - } - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape); - } - } - - template - CUTLASS_DEVICE - PipelineTmaUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) - , params_(params) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) - , full_barrier_ptr_(&storage.full_barrier_[0]) { - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_, cluster_shape, mcast_direction); - } - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape, mcast_direction); - } - } - - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return impl_.producer_try_acquire(state, skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - impl_.producer_acquire(state, barrier_token); - } - - CUTLASS_DEVICE - void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) { - impl_.producer_expect_transaction(state, transaction_bytes); - } - - // NOP for TMA based mainloop - CUTLASS_DEVICE - void producer_commit(PipelineState state, uint32_t bytes) { - impl_.producer_commit(state, bytes); - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - impl_.producer_tail(state); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return impl_.producer_get_barrier(state); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return impl_.consumer_try_wait(state, skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - impl_.consumer_wait(state, barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index(), false); - } - -private: - Impl impl_; - Params params_; - EmptyBarrier *empty_barrier_ptr_; - FullBarrier *full_barrier_ptr_; - uint16_t block_id_mask_ = 0; - static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; - - // Consumer signalling Producer of completion - // Ensures all blocks in the Same Row and Column get notifed. - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip) { - detail::pipeline_check_is_consumer(params_.role); - uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); - if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 - if (!skip) { - cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); - } - } - else { - if (!skip) { - if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { - cutlass::arch::umma_arrive(smem_ptr); - } - else { - cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); - } - } - } - } -}; - -// Producer-consumer pipeline implementation -// for UMMA consumer. In this case, UMMA barrier arrives are -// used by consumer_release. -template > -class PipelineUmmaConsumerAsync { -public: - static constexpr uint32_t Stages = Stages_; - using AtomThrShape_MNK = AtomThrShape_MNK_; -private: - using Impl = PipelineAsync; -public: - using FullBarrier = typename Impl::FullBarrier; - using EmptyBarrier = typename Impl::EmptyBarrier; - using ProducerBarrierType = typename Impl::ProducerBarrierType; - using ConsumerBarrierType = typename Impl::ConsumerBarrierType; - using PipelineState = typename Impl::PipelineState; - using SharedStorage = typename Impl::SharedStorage; - using ThreadCategory = typename Impl::ThreadCategory; - using Params = typename Impl::Params; - - template - CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { - // Calculate consumer mask - if (params_.role == ThreadCategory::Consumer) { - // The leader threadblock executing the MMA_2x1SM instruction will signal its peer - // threadblock when it is done with MMA operations. tmem_sync_mask encodes the - // position of peer SMs in the cluster - tmem_sync_mask_ = detail::calculate_umma_peer_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); - } - } - - // Constructor by default initializes barriers and calculates masks. - // These operations can be explicity deferred by specifying InitBarriers and InitMasks. - // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. - template - CUTLASS_DEVICE - PipelineUmmaConsumerAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, InitBarriers{}) - , params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape); - } - } - - //////////////////// - // Producer APIs - //////////////////// - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return impl_.producer_try_acquire(state, skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - impl_.producer_acquire(state, barrier_token); - } - - template - CUTLASS_DEVICE - void producer_commit(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { - cute::forward(user_defined_arrive_op)(producer_get_barrier(state)); - producer_commit(state); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - if constexpr (is_2sm_mma) { - producer_commit_2x1SM(state.index()); - } else { - impl_.producer_commit(state); - } - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - impl_.producer_tail(state); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return impl_.producer_get_barrier(state.index()); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return impl_.consumer_try_wait(state, skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - if (barrier_token == BarrierStatus::WaitAgain) { - impl_.consumer_wait(state); - } - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - -private: - Impl impl_; - Params params_; - FullBarrier* full_barrier_ptr_ = nullptr; - EmptyBarrier* empty_barrier_ptr_ = nullptr; - uint16_t tmem_sync_mask_ = 0; - static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; - - CUTLASS_DEVICE - void producer_commit_2x1SM(uint32_t stage) { - detail::pipeline_check_is_producer(params_.role); - uint64_t* smem_ptr = reinterpret_cast(&full_barrier_ptr_[stage]); - cutlass::arch::umma_arrive_2x1SM_sm0(smem_ptr); - static_assert(is_2sm_mma, "ERROR : AtomThrShape_MNK does not correspond to a 2SM MMMA"); - } - - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip = false) { - detail::pipeline_check_is_consumer(params_.role); - uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); - if constexpr (is_2sm_mma) { - cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, tmem_sync_mask_); - } - else { - cutlass::arch::umma_arrive(smem_ptr); - } - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// CLC Async Pipeline class for Blackwell UMMA -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace PipelineDetail { - -template -using PipelineCLCFetchAsyncPipelineState = cutlass::PipelineState; - -template -struct PipelineCLCFetchAsyncSharedStorage { - using FullBarrier = cutlass::arch::ClusterTransactionBarrier; - using EmptyBarrier = cutlass::arch::ClusterBarrier; - - FullBarrier full_barrier_[static_cast(Stages_)]; - EmptyBarrier empty_barrier_[static_cast(Stages_)]; -}; - -} // namespace PipelineDetail - -template > -class PipelineCLCFetchAsync { - -public: - static constexpr uint32_t Stages = Stages_; - using PipelineState = PipelineDetail::PipelineCLCFetchAsyncPipelineState; - using SharedStorage = PipelineDetail::PipelineCLCFetchAsyncSharedStorage; - using FullBarrier = typename SharedStorage::FullBarrier; - using EmptyBarrier = typename SharedStorage::EmptyBarrier; - - enum class ThreadCategory { - NonParticipant, - Producer, - Consumer, - ProducerConsumer - }; - - struct Params { - uint32_t transaction_bytes = 0; - ThreadCategory role = ThreadCategory::NonParticipant; - uint32_t is_leader = 0; - uint32_t num_consumers = 0; - uint32_t producer_blockid = 0; - uint32_t producer_arv_count = 0; - uint32_t consumer_arv_count = 0; - int initializing_warp = 0; - }; - - // Constructor - CUTLASS_DEVICE - PipelineCLCFetchAsync(SharedStorage& storage, Params const& params) : - params_(params), - full_barrier_ptr_(&storage.full_barrier_[0]), - empty_barrier_ptr_(&storage.empty_barrier_[0]) { - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); - CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); - } - cutlass::arch::fence_barrier_init(); - - cluster_size_ = []() { auto cs = cute::cluster_shape(); return cs.x * cs.y; }(); - } - - // Constructor - CUTLASS_DEVICE - PipelineCLCFetchAsync(SharedStorage& storage, Params const& params, ClusterShape cluster_shape) - : params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - int warp_idx = canonical_warp_idx_sync(); - if (warp_idx == params.initializing_warp) { - // Barrier FULL and EMPTY init - CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); - CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); - } - cutlass::arch::fence_barrier_init(); - - cluster_size_ = cute::size<0>(cluster_shape) - * cute::size<1>(cluster_shape) - * cute::size<2>(cluster_shape); - } - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return producer_try_acquire(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - producer_acquire(state.index(), state.phase(), barrier_token); - } - - // Manual completion of transaction count - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index(), state.phase()); - } - - // Prevents early exit of producer blocks in Cluster. - // Does NOT reset transaction bytes. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - detail::pipeline_check_is_producer(params_.role); - for (int count = 0; count < Stages; ++count) { - bool done = empty_barrier_ptr_[state.index()].test_wait(state.phase()); - if (!done) { - empty_barrier_ptr_[state.index()].wait(state.phase()); - } - ++state; - } - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_try_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - consumer_wait(state.index(), state.phase(), barrier_token); - } - - // Consumer signalling Producer of completion - // Notifies the producer block in the Cluster - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - - CUTLASS_HOST_DEVICE - uint32_t producer_get_barrier(PipelineState state) { - return cute::cast_smem_ptr_to_uint(reinterpret_cast(&full_barrier_ptr_[state.index()])); - } - -private: - FullBarrier *full_barrier_ptr_ = nullptr; - EmptyBarrier *empty_barrier_ptr_ = nullptr; - Params params_; - int lane_idx_ = canonical_lane_idx(); - int cluster_size_; - - CUTLASS_DEVICE - ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_producer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_stat = empty_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_stat)}; - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { - detail::pipeline_check_is_producer(params_.role); - // 1. Wait for empty barrier to be ready - // 2. Set the transaction bytes set to occur on the Full barrier for all blocks - if (barrier_token == BarrierStatus::WaitAgain) { - empty_barrier_ptr_[stage].wait(phase); - } - - full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes, lane_idx_, uint32_t(lane_idx_ < cluster_size_)); - } - - CUTLASS_DEVICE - void producer_commit(uint32_t stage, uint32_t phase) { - int cluster_size_ = []() { auto cs = cute::cluster_shape(); return cs.x * cs.y; }(); - full_barrier_ptr_[stage].complete_transaction(lane_idx_, params_.transaction_bytes, uint32_t(lane_idx_ < cluster_size_)); - } - - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_stat = full_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_stat)}; - } - - // Wait for producer to commit transactions - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { - detail::pipeline_check_is_consumer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - full_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void consumer_release(uint32_t stage) { - detail::pipeline_check_is_consumer(params_.role); - empty_barrier_ptr_[stage].arrive(params_.producer_blockid); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Empty Pipeline class -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -class PipelineEmpty { -public: - static constexpr uint32_t Stages = 0; - using PipelineState = cutlass::PipelineState<0>; - struct Params {}; - struct SharedStorage {}; - - // Constructor - CUTLASS_DEVICE - PipelineEmpty(SharedStorage& storage, Params const& params) {} - - // Constructor - CUTLASS_DEVICE - PipelineEmpty(SharedStorage&& storage, Params const& params) {} - - // Constructor with throwaway ClusterShape - template > - CUTLASS_DEVICE - PipelineEmpty(SharedStorage&& storage, Params const& params, ClusterShape) {} - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA (producer - consumer) Async Pipeline classes for Blackwell Sparse UMMA -// This is designed for the pattern that kernel has two different staged tensors. (AB and metadata) -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Producer-consumer pipeline implementation -// for UMMA producer. In this case, UMMA barrier arrives are used -// by producer_commit. Use case, accumulator generation as -// the result of MMA instructions. -template < - int Stages_, - class ClusterShape = Shape, - class AtomThrShape_MNK_ = Shape<_1,_1,_1> -> -class PipelineTmaSparseUmmaAsync { -public: - static constexpr uint32_t Stages = Stages_; - using AtomThrShape_MNK = AtomThrShape_MNK_; -private: - using Impl = PipelineTmaUmmaAsync; -public: - using FullBarrier = typename Impl::FullBarrier; - using EmptyBarrier = typename Impl::EmptyBarrier; - using ProducerBarrierType = typename Impl::ProducerBarrierType; - using ConsumerBarrierType = typename Impl::ConsumerBarrierType; - using PipelineState = typename Impl::PipelineState; - using SharedStorage = typename Impl::SharedStorage; - using ThreadCategory = typename Impl::ThreadCategory; - using Params = typename Impl::Params; - - struct ParamsMetadata { - uint32_t transaction_bytes = 0; - uint32_t metadata_transaction_bytes = 0; - }; - - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { - Impl::init_barriers(storage, params, cluster_shape); - } - - CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { - impl_.init_masks(cluster_shape, block_id_in_cluster); - } - - // Constructor by default initializes barriers and calculates masks. - // These operations can be deferred by specifying InitBarriers and InitMasks. - // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. - template - CUTLASS_DEVICE - PipelineTmaSparseUmmaAsync(SharedStorage& storage, Params params, ParamsMetadata params_metadata, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) - , params_(params) - , params_metadata_(params_metadata) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) - , full_barrier_ptr_(&storage.full_barrier_[0]) { - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_, cluster_shape); - } - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_masks(cluster_shape); - } - } - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return impl_.producer_try_acquire(state, skip_wait); - } - - // Customized for metadata load - CUTLASS_DEVICE - void producer_acquire(PipelineState state, bool load_e, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - producer_acquire(state.index(), state.phase(), load_e, barrier_token); - } - - // Customized for metadata load - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - producer_acquire(state, true, barrier_token); - } - - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - return impl_.producer_tail(state); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return impl_.producer_get_barrier(state); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return impl_.consumer_try_wait(state, skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - return impl_.consumer_wait(state, barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - return impl_.consumer_release(state); - } - -private: - Impl impl_; - Params params_; - ParamsMetadata params_metadata_; - EmptyBarrier *empty_barrier_ptr_{nullptr}; - FullBarrier *full_barrier_ptr_{nullptr}; - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, bool load_e, ProducerToken barrier_token) { - detail::pipeline_check_is_producer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - empty_barrier_ptr_[stage].wait(phase); - } - uint32_t bytes_now = load_e ? params_metadata_.transaction_bytes + params_metadata_.metadata_transaction_bytes : params_metadata_.transaction_bytes; - - if (params_.is_leader) { - full_barrier_ptr_[stage].arrive_and_expect_tx(bytes_now); - } - } - -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp deleted file mode 100644 index aae17d98aafc045be0bfda867cad95717b19e74d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pipeline/sm90_pipeline.hpp +++ /dev/null @@ -1,1388 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/layout.hpp" -#include "cute/layout_composed.hpp" // cute::composition -#include "cute/swizzle.hpp" // cute::Swizzle -#include "cute/swizzle_layout.hpp" // cute::composition -#include "cute/util/type_traits.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cute/container/array.hpp" -#include "cute/numeric/integral_constant.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/detail/dependent_false.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using namespace cute; - -namespace detail { - -// Helper function for DEBUG checks -template -CUTLASS_DEVICE -bool pipeline_is_producer(ThreadCategory role) { - return (role == ThreadCategory::Producer || role == ThreadCategory::ProducerConsumer); -} - -template -CUTLASS_DEVICE -void pipeline_check_is_producer(ThreadCategory role) { - #ifndef NDEBUG - if (!pipeline_is_producer(role)) { - asm volatile ("brkpt;\n" ::); - } - #endif -} - -template -CUTLASS_DEVICE -bool pipeline_is_consumer(ThreadCategory role) { - return (role == ThreadCategory::Consumer || role == ThreadCategory::ProducerConsumer); -} - -template -CUTLASS_DEVICE -void pipeline_check_is_consumer(ThreadCategory role) { - #ifndef NDEBUG - if (!pipeline_is_consumer(role)) { - asm volatile ("brkpt;\n" ::); - } - #endif -} - -CUTLASS_DEVICE -cute::tuple spread_arrivals_to_warp(int thread_idx_in_warp) { - constexpr uint32_t MaxClusterSize = 16; - bool is_signaling_thread = (thread_idx_in_warp % (32 / MaxClusterSize)) == 0; - auto layout = Layout,Stride<_4, _1>>{}; - uint32_t thread_row = thread_idx_in_warp / 8; - uint32_t thread_col = (thread_idx_in_warp % 8) / 2; - uint32_t dst_blockid = layout(thread_row, thread_col); - return cute::make_tuple(is_signaling_thread, dst_blockid); -} - -CUTLASS_DEVICE -cute::tuple spread_arrivals_to_warpgroup(int thread_idx_in_warpgroup, int warp_idx) { - constexpr uint32_t MaxClusterSize = 16; - bool is_signaling_thread = (thread_idx_in_warpgroup % (NumThreadsPerWarpGroup / MaxClusterSize)) == 0; - auto layout = cute::composition(Swizzle<2,0,-2>{}, - Layout,Stride<_4,_1>>{}); - uint32_t thread_row = warp_idx % 4; - uint32_t thread_col = (thread_idx_in_warpgroup / 8) % 4; - uint32_t dst_blockid = layout(thread_row, thread_col); - return cute::make_tuple(is_signaling_thread, dst_blockid); -} -} // namespace detail - -enum class BarrierStatus : uint32_t { - WaitAgain = 0u, - WaitDone = 1u, -}; - -class ArrivalToken { -public: - CUTLASS_HOST_DEVICE - ArrivalToken(BarrierStatus barrier_status) : barrier_status_(barrier_status) {} - - CUTLASS_HOST_DEVICE - ArrivalToken() = delete; - - CUTLASS_HOST_DEVICE - BarrierStatus get() const { - return barrier_status_; - } - - CUTLASS_HOST_DEVICE - bool operator==(ArrivalToken const& other) const { - return barrier_status_ == other.get(); - } - -private: - BarrierStatus barrier_status_; - - CUTLASS_HOST_DEVICE - friend bool operator==(const ArrivalToken& left, const BarrierStatus& right) { - return left.get() == right; - } - - CUTLASS_HOST_DEVICE - friend bool operator==(const BarrierStatus& left, const ArrivalToken& right) { - return left == right.get(); - } - - CUTLASS_HOST_DEVICE - friend bool operator!=(const ArrivalToken& left, const BarrierStatus& right) { - return left.get() != right; - } - - CUTLASS_HOST_DEVICE - friend bool operator!=(const BarrierStatus& left, const ArrivalToken& right) { - return left != right.get(); - } -}; - -class ProducerToken : public ArrivalToken { - using ArrivalToken::ArrivalToken; -}; - -class ConsumerToken : public ArrivalToken { - using ArrivalToken::ArrivalToken; -}; - -// Circular Buffer Index + Associated Phase -// Assumes only one operation possible - i.e., ++ -template -struct PipelineState { - - static constexpr uint32_t Stages = Stages_; - - int index_ = 0; - uint32_t phase_ = 0; - uint32_t count_ = 0; - - CUTLASS_DEVICE - PipelineState(): index_{}, phase_{}, count_{} {} - - CUTLASS_DEVICE - PipelineState(int index, uint32_t phase, uint32_t count) - : index_(index) - , phase_(phase) - , count_(count) {} - - CUTLASS_DEVICE - int index() const { - return index_; - } - - CUTLASS_DEVICE - uint32_t phase() const { - return phase_; - } - - CUTLASS_DEVICE - uint32_t count() const { - return count_; - } - - CUTLASS_DEVICE - void operator++() { - if constexpr (Stages > 0) { - ++index_; - ++count_; - if (index_ == Stages) { - index_ = 0; - phase_ ^= 1; - } - } - } - - CUTLASS_DEVICE - PipelineState& operator+=(uint32_t num_iterations) { - return advance(num_iterations); - } - - CUTLASS_DEVICE - PipelineState& operator=(PipelineState const& other) { - index_ = other.index(); - phase_ = other.phase(); - count_ = other.count(); - return *this; - } - - CUTLASS_DEVICE - PipelineState& advance(uint32_t num_iterations) { - if constexpr (Stages > 0) { - // Number of iterations cross over the stage boundary => flipped phase - if ((num_iterations < Stages) && (index_ + num_iterations) >= Stages ) { - phase_ ^= 1; - } - // How many times number of iterations cross over the stage boundary and - // end up on a odd number => flipped phase - if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { - phase_ ^= 1; - } - index_ = (index_ + num_iterations) % Stages; - count_ += num_iterations; - } - return *this; - } - - CUTLASS_DEVICE - static PipelineState make_pipeline_state(PipelineState start_state, uint32_t num_iterations) { - return start_state.advance(num_iterations); - } -}; - -template -CUTLASS_DEVICE -PipelineState make_producer_start_state() { - // Producer starts with an opposite phase as the buffers are initially empty - constexpr int InitialProducerStage = 0; - constexpr uint32_t InitialProducerPhase = 1; - constexpr uint32_t InitialProducerCount = 0; - return {InitialProducerStage, InitialProducerPhase, InitialProducerCount}; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA load (producer) Async Pipeline class -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Assumptions : Constructor is visible Cluster-wide (as it needs a Cluster-Sync) -// We have exactly one thread elected in the Producer as the "leader" -// Currently, it is optional to elect a leader for the Consumers -template -class PipelineTmaAsync { -public: - using FullBarrier = cutlass::arch::ClusterTransactionBarrier; - using EmptyBarrier = cutlass::arch::ClusterBarrier; - using ProducerBarrierType = FullBarrier::ValueType; - using ConsumerBarrierType = EmptyBarrier::ValueType; - static constexpr uint32_t Stages = Stages_; - using PipelineState = cutlass::PipelineState; - - struct SharedStorage { - FullBarrier full_barrier_[Stages]; - EmptyBarrier empty_barrier_[Stages]; - }; - - enum class ThreadCategory { - NonParticipant, - Producer, - Consumer, - ProducerConsumer - }; - - struct Params { - uint32_t transaction_bytes = 0; - ThreadCategory role = ThreadCategory::NonParticipant; - uint32_t is_leader = 0; - uint32_t num_consumers = 0; // Number of consumer threads - uint32_t num_producers = 1; // Number of producer threads - int initializing_warp = 0; - }; - - template - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { - int warp_idx = canonical_warp_idx_sync(); - bool is_initializing_warp = (warp_idx == 0); - is_initializing_warp = (warp_idx == params.initializing_warp); - if (is_initializing_warp) { - // Barrier FULL and EMPTY init - uint32_t const producer_arv_cnt = params.num_producers; - uint32_t const num_consumer_warpgroups_per_cluster = cute::ceil_div(params.num_consumers, static_cast(NumThreadsPerWarpGroup)); - uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1 - if (cute::size(cluster_shape) > 1) { - multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * - num_consumer_warpgroups_per_cluster; - } - CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); - CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); - } - cutlass::arch::fence_barrier_init(); - } - - template - CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - int warp_idx = canonical_warp_idx_sync(); - int thread_idx = threadIdx.x; - int lane_predicate = cute::elect_one_sync(); - - static_assert(cute::is_same_v || cute::is_same_v); - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_, cluster_shape); - } - - if constexpr (cute::is_same_v) { - // Logic to optimally schedule Empty Arrives - // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) - dim3 block_id = cute::block_id_in_cluster(); - auto cluster_size = cute::size(cluster_shape); - - if (cluster_size == 1) { - is_signaling_thread_ = true; - dst_blockid_ = 0; - } - else { - // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) - if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { - auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warpgroup(thread_idx % NumThreadsPerWarpGroup, warp_idx); - is_signaling_thread_ = is_signaling_thread; - dst_blockid_ = dst_blockid; - } - else if (params_.num_consumers == 32) { - auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warp(thread_idx % 32); - is_signaling_thread_ = is_signaling_thread; - dst_blockid_ = dst_blockid; - } - else { - is_signaling_thread_ = 0; - #ifndef NDEBUG - asm volatile ("brkpt;\n" ::); - #endif - } - - // STEP 2: Find if this dst block-id needs an arrival for this problem - is_signaling_thread_ &= dst_blockid_ < cluster_size; - is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); - } - } - } - - // Constructor - template - CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape) - : PipelineTmaAsync(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } - - template - CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) - : PipelineTmaAsync(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } - - template - CUTLASS_DEVICE - bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { - return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) || - ( - ((dst_block_id / cute::size<0>(cluster_shape)) == block_id.y) - )); - } - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return producer_try_acquire(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state) { - producer_acquire(state.index(), state.phase()); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token) { - producer_acquire(state.index(), state.phase(), barrier_token); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state, uint32_t bytes) { - producer_commit(state.index(), bytes); - } - - template - CUTLASS_DEVICE - void producer_commit(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { - cute::forward(user_defined_arrive_op)(producer_get_barrier(state.index()));; - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - detail::pipeline_check_is_producer(params_.role); - for (int count = 0; count < Stages; ++count) { - empty_barrier_ptr_[state.index()].wait(state.phase()); - ++state; - } - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return producer_get_barrier(state.index()); - } - - CUTLASS_DEVICE - void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) { - producer_expect_transaction(state.index(), transaction_bytes); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_try_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_test_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state) { - consumer_wait(state.index(), state.phase()); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token) { - consumer_wait(state.index(), state.phase(), barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - -private: - uint32_t dst_blockid_ = 0; - uint32_t is_signaling_thread_ = 0; - FullBarrier *full_barrier_ptr_ = nullptr; - EmptyBarrier *empty_barrier_ptr_ = nullptr; - Params params_; - - CUTLASS_DEVICE - ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_producer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = empty_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase) { - empty_barrier_ptr_[stage].wait(phase); - - if (params_.is_leader) { - full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); - } - #ifndef NDEBUG - if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { - asm volatile ("brkpt;\n" ::); - } - - // Most likely you have elected more than one leader - if (params_.is_leader && (threadIdx.x % 32 != 0)) { - asm volatile ("brkpt;\n" ::); - } - #endif - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { - detail::pipeline_check_is_producer(params_.role); - if (barrier_token != BarrierStatus::WaitDone) { - empty_barrier_ptr_[stage].wait(phase); - } - - if (params_.is_leader) { - full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); - } - #ifndef NDEBUG - if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { - asm volatile ("brkpt;\n" ::); - } - - // Most likely you have elected more than one leader - if (params_.is_leader && (threadIdx.x % 32 != 0)) { - asm volatile ("brkpt;\n" ::); - } - #endif - } - - CUTLASS_DEVICE - void producer_expect_transaction(uint32_t stage, uint32_t transaction_bytes) { - detail::pipeline_check_is_producer(params_.role); - if (params_.is_leader) { - full_barrier_ptr_[stage].expect_transaction(transaction_bytes); - } - } - - // NOP for TMA based mainloop - CUTLASS_DEVICE - void producer_commit(uint32_t stage, uint32_t bytes) { - // Below code is used only for unit-testing (in the absence of TMA commit) - #if CUTLASS_UNIT_TEST_PIPELINE - if (params_.is_leader) { - // STEP 1 : Commit to self - full_barrier_ptr_[stage].complete_transaction(bytes); - - // STEP 2 : Commit to other blocks in our cluster - auto cluster_shape = cute::cluster_shape(); - Layout block_layout_in_cluster = make_layout(cluster_shape); - dim3 local_block_id = cute::block_id_in_cluster(); - - CUTLASS_PRAGMA_UNROLL - for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { - uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); - full_barrier_ptr_[stage].complete_transaction(dst_block_id, bytes, n!=local_block_id.y); - } - - CUTLASS_PRAGMA_UNROLL - for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { - uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); - full_barrier_ptr_[stage].complete_transaction(dst_block_id, bytes, m!=local_block_id.x); - } - } - #endif - } - - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = full_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = full_barrier_ptr_[stage].test_wait(phase); - return {static_cast(barrier_status)}; - } - - // Wait for producer to commit transactions (done by TMA) - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase) { - detail::pipeline_check_is_consumer(params_.role); - full_barrier_ptr_[stage].wait(phase); - } - - // Wait for producer to commit transactions (done by TMA) - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { - detail::pipeline_check_is_consumer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - full_barrier_ptr_[stage].wait(phase); - } - } - - // Consumer signalling Producer of completion - // Ensures all blocks in the Same Row and Column get notifed. - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip = false) { - detail::pipeline_check_is_consumer(params_.role); - empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signaling_thread_ & (!skip)); - #ifndef NDEBUG - if (params_.role == ThreadCategory::Producer || params_.role == ThreadCategory::NonParticipant) { - asm volatile ("brkpt;\n" ::); - } - #endif - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(uint32_t stage) { - return reinterpret_cast(&full_barrier_ptr_[stage]); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA store pipeline class -// producer-only class, no async barriers between threads because consumer is TMA unit -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -template < - int Stages_, - // The number of committed TMA store batches that can be in flight upon return of producer acquire - int UnacquiredStages_ = Stages_-1 -> -class PipelineTmaStore { -public: - static constexpr uint32_t Stages = Stages_; - static_assert(Stages_ > 0); - static_assert(UnacquiredStages_ >= 0); - static constexpr uint32_t UnacquiredStages = static_cast(UnacquiredStages_); - using PipelineState = cutlass::PipelineState; - - struct Params { - bool always_wait = false; - }; - - CUTLASS_DEVICE - PipelineTmaStore(Params params = {}) : params_(params) {} - - //////////////////// - // Producer APIs - //////////////////// - // Wait for the least recently committed batch of TMA stores to complete - CUTLASS_DEVICE - void producer_acquire(PipelineState state) { - producer_acquire(state.index(), state.count()); - } - - // Commit the most recently issued batch of TMA stores - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index(), state.count()); - } - - // Wait for all TMA stores to complete - CUTLASS_DEVICE - void producer_tail([[maybe_unused]] PipelineState state) { - tma_store_wait<0>(); - } - -private: - Params params_; - - // Wait for the least recently committed batch of TMA stores to complete - // or until at most UnacquiredStages TMA store batches are in-flight (if specified) - CUTLASS_DEVICE - void producer_acquire([[maybe_unused]] uint32_t stage, uint32_t count) { - if (params_.always_wait || count > UnacquiredStages) { - tma_store_wait(); - } - } - - // Commit the most recently issued batch of TMA stores - CUTLASS_DEVICE - void producer_commit([[maybe_unused]] uint32_t stage, [[maybe_unused]] uint32_t count) { - tma_store_arrive(); - } -}; - -template <> -class PipelineTmaStore< /* Stages_ = */ 0, /* UnacquiredStages = Stages_ - 1 = */ -1 > { -public: - static constexpr uint32_t Stages = 0; - static constexpr uint32_t UnacquiredStages = 0; - using PipelineState = cutlass::PipelineState; - - struct Params { - bool always_wait = false; - }; - - PipelineTmaStore() = default; - CUTLASS_DEVICE - PipelineTmaStore(Params params) : params_(params) {} - - //////////////////// - // Producer APIs - //////////////////// - - template - CUTLASS_DEVICE - void producer_acquire(PipelineState /* state */, - ThisTemplateParameterExistsOnlyForDependentFalse* /* unused */ = nullptr) { - static_assert(cutlass::detail::dependent_false, - "It is never valid to call PipelineTmaStore<0>::producer_acquire"); - } - - // Commit the most recently issued batch of TMA stores - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index(), state.count()); - } - - // Wait for all TMA stores to complete - CUTLASS_DEVICE - void producer_tail([[maybe_unused]] PipelineState state) { - tma_store_wait<0>(); - } - -private: - Params params_; - - // Commit the most recently issued batch of TMA stores - CUTLASS_DEVICE - void producer_commit([[maybe_unused]] uint32_t stage, [[maybe_unused]] uint32_t count) { - tma_store_arrive(); - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Simple producer-consumer async Pipeline class using producer transaction barriers -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -template -class PipelineTransactionAsync { -public: - using FullBarrier = cutlass::arch::ClusterTransactionBarrier; - using EmptyBarrier = cutlass::arch::ClusterBarrier; - using ProducerBarrierType = FullBarrier::ValueType; - using ConsumerBarrierType = EmptyBarrier::ValueType; - static constexpr uint32_t Stages = Stages_; - using PipelineState = cutlass::PipelineState; - - struct SharedStorage { - cute::array full_barrier_; - cute::array empty_barrier_; - }; - - enum class ThreadCategory { - NonParticipant, - Producer, - Consumer, - ProducerConsumer - }; - - struct Params { - ThreadCategory role = ThreadCategory::NonParticipant; - uint32_t transaction_bytes = 0; - uint32_t producer_arv_count = 1; - uint32_t consumer_arv_count = 1; - uint32_t dst_blockid = cute::block_rank_in_cluster(); - int initializing_warp = 0; - }; - - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params const& params) { - FullBarrier *full_barrier_ptr = storage.full_barrier_.data(); - EmptyBarrier *empty_barrier_ptr = storage.empty_barrier_.data(); - int warp_idx = canonical_warp_idx_sync(); - bool is_initializing_warp = (warp_idx == 0); - is_initializing_warp = (warp_idx == params.initializing_warp); - - if (is_initializing_warp) { - // Barrier FULL and EMPTY init - CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); - CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - full_barrier_ptr, empty_barrier_ptr, params.producer_arv_count, params.consumer_arv_count); - } - cutlass::arch::fence_barrier_init(); - } - - // Constructor - template - CUTLASS_DEVICE - PipelineTransactionAsync(SharedStorage& storage, Params const& params, InitBarriers = cute::true_type{}) - : params_(params) - , full_barrier_ptr_(storage.full_barrier_.data()) - , empty_barrier_ptr_(storage.empty_barrier_.data()) { - - int warp_idx = canonical_warp_idx_sync(); - int lane_predicate = cute::elect_one_sync(); - - static_assert(cute::is_same_v || cute::is_same_v); - - if constexpr (cute::is_same_v) { - init_barriers(storage, params); - } - - } - - // Constructor - CUTLASS_DEVICE - PipelineTransactionAsync(SharedStorage& storage, Params const& params) : - PipelineTransactionAsync(storage, params, cute::true_type{}) { } - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return producer_try_acquire(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - producer_acquire(state.index(), state.phase(), barrier_token); - } - - // Perform an expect-tx operation on the stage's full barrier. Must be called by 1 thread - CUTLASS_DEVICE - void producer_expect_transaction(PipelineState state) { - producer_expect_transaction(state.index()); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index()); - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - for (int count = 0; count < Stages; ++count) { - producer_acquire(state); - ++state; - } - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return producer_get_barrier(state.index()); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_try_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_test_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - consumer_wait(state.index(), state.phase(), barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - -private: - FullBarrier *full_barrier_ptr_ = nullptr; - EmptyBarrier *empty_barrier_ptr_ = nullptr; - Params params_; - - CUTLASS_DEVICE - ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_producer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = empty_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { - detail::pipeline_check_is_producer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - empty_barrier_ptr_[stage].wait(phase); - } - } - - // Perform an expect-tx operation on the stage's full barrier. Must be called by 1 thread - CUTLASS_DEVICE - void producer_expect_transaction(uint32_t stage) { - detail::pipeline_check_is_producer(params_.role); - full_barrier_ptr_[stage].expect_transaction(params_.transaction_bytes); - } - - CUTLASS_DEVICE - void producer_commit(uint32_t stage) { - detail::pipeline_check_is_producer(params_.role); - full_barrier_ptr_[stage].arrive(params_.dst_blockid); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(uint32_t stage) { - return reinterpret_cast(&full_barrier_ptr_[stage]); - } - - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = full_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = full_barrier_ptr_[stage].test_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { - detail::pipeline_check_is_consumer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - full_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip = false) { - detail::pipeline_check_is_consumer(params_.role); - empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Simple producer-consumer async Pipeline class -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace PipelineDetail { - template - using PipelineAsyncPipelineState = cutlass::PipelineState; - - template - struct PipelineAsyncSharedStorage { - using FullBarrier = cutlass::arch::ClusterBarrier; - using EmptyBarrier = cutlass::arch::ClusterBarrier; - - FullBarrier full_barrier_[Stages]; - EmptyBarrier empty_barrier_[Stages]; - }; -}; - -template -class PipelineAsync { -public: - static constexpr uint32_t Stages = Stages_; - using SharedStorage = PipelineDetail::PipelineAsyncSharedStorage; - using FullBarrier = typename SharedStorage::FullBarrier; - using EmptyBarrier = typename SharedStorage::EmptyBarrier; - using ProducerBarrierType = typename FullBarrier::ValueType; - using ConsumerBarrierType = typename EmptyBarrier::ValueType; - using PipelineState = PipelineDetail::PipelineAsyncPipelineState; - - enum class ThreadCategory { - NonParticipant, - Producer, - Consumer, - ProducerConsumer - }; - - struct Params { - ThreadCategory role = ThreadCategory::NonParticipant; - uint32_t producer_arv_count = 1; - uint32_t consumer_arv_count = 1; - uint32_t dst_blockid = cute::block_rank_in_cluster(); - int initializing_warp = 0; - }; - - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params) { - int warp_idx = canonical_warp_idx_sync(); - bool is_initializing_warp = (warp_idx == 0); - is_initializing_warp = (warp_idx == params.initializing_warp); - if (is_initializing_warp) { - // Barrier FULL and EMPTY init - CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); - CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); - } - cutlass::arch::fence_barrier_init(); - } - - template - CUTLASS_DEVICE - PipelineAsync( - SharedStorage& storage, - Params const& params, - InitBarriers = {}) : - params_(params), - full_barrier_ptr_(&storage.full_barrier_[0]), - empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params_); - } - } - - CUTLASS_DEVICE - PipelineAsync( - SharedStorage& storage, - Params const& params) : - PipelineAsync(storage, params, cute::true_type{}) { } - - // Default assumption when only storage is passed is : - // => single producer, single consumer & they are in the same block (within the Cluster) - CUTLASS_DEVICE - PipelineAsync(SharedStorage& storage) - : PipelineAsync(storage, {}, cute::true_type{}) {} - - //////////////////// - // Producer APIs - //////////////////// - // Four member functions are always used in pairs: - // - // * producer_try_acquire and producer_acquire, and - // * consumer_try_wait and consumer_wait. - // - // The two functions with "try" in their names are called "try" functions, - // and the other two are conceptually "finalize" functions. - // The "try" function in each pair starts the process of waiting on the barrier to flip. - // It opportunistically waits for an implementation-dependent timeout. - // Whether or not the barrier has flipped yet, the try function will return a token. - // If the token indicates that the barrier has not flipped, - // then the token must be passed into the corresponding "finalize" function. - // The finalize function will then block until the barrier has flipped. - // If the token indicates that the barrier _has_ flipped, - // then it is still correct to pass it into the finalize function. - // The finalize function will return immediately in that case. - CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { - return producer_try_acquire(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { - producer_acquire(state.index(), state.phase(), barrier_token); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index()); - } - - template - CUTLASS_DEVICE - void producer_commit(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { - cute::forward(user_defined_arrive_op)(producer_get_barrier(state.index())); - producer_commit(state); - } - - // Prevents early exit of producer blocks in Cluster. - // This should be called once before kernel exits. - CUTLASS_DEVICE - void producer_tail(PipelineState state) { - for (int count = 0; count < Stages; ++count) { - producer_acquire(state); - ++state; - } - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { - return producer_get_barrier(state.index()); - } - - //////////////////// - // Consumer APIs - //////////////////// - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_try_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(PipelineState state, uint32_t skip_wait = false) { - return consumer_test_wait(state.index(), state.phase(), skip_wait); - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { - consumer_wait(state.index(), state.phase(), barrier_token); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(uint32_t stage) { - return reinterpret_cast(&full_barrier_ptr_[stage]); - } - -private: - Params params_; - FullBarrier *full_barrier_ptr_; - EmptyBarrier *empty_barrier_ptr_; - - CUTLASS_DEVICE - ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_producer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = empty_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { - detail::pipeline_check_is_producer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - empty_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void producer_commit(uint32_t stage) { - detail::pipeline_check_is_producer(params_.role); - full_barrier_ptr_[stage].arrive(); - } - - CUTLASS_DEVICE - ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = full_barrier_ptr_[stage].try_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { - detail::pipeline_check_is_consumer(params_.role); - if (skip_wait) { - return {BarrierStatus::WaitDone}; - } - bool barrier_status = full_barrier_ptr_[stage].test_wait(phase); - return {static_cast(barrier_status)}; - } - - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase) { - detail::pipeline_check_is_consumer(params_.role); - bool done = full_barrier_ptr_[stage].test_wait(phase); - if (!done) { - full_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { - detail::pipeline_check_is_consumer(params_.role); - if (barrier_token == BarrierStatus::WaitAgain) { - full_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void consumer_release(uint32_t stage) { - detail::pipeline_check_is_consumer(params_.role); - empty_barrier_ptr_[stage].arrive(params_.dst_blockid); - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Barrier to ensure an Ordered Sequence between -// SequenceLength number of groups (each with group_size participants) executing SequenceDepth Stages -// i.e., for all i < j - only after id "i" arrives at a particular stage "m" -// will the wait() for id "j" succeed for the same stage -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace PipelineDetail { - -template -struct OrderedSequenceBarrierSharedStorage { - using Barrier = cutlass::arch::ClusterBarrier; - Barrier barrier_[SequenceDepth][SequenceLength]; -}; - -} // namespace PipelineDetail - -template -class OrderedSequenceBarrier { -public: - static constexpr int SequenceDepth = SequenceDepth_; - static constexpr int SequenceLength = SequenceLength_; - using SharedStorage = - PipelineDetail::OrderedSequenceBarrierSharedStorage; - using Barrier = typename SharedStorage::Barrier; - - struct Params { - uint32_t group_id; - uint32_t group_size; - int initializing_warp = 0; - }; - -private: - // In future this Params object can be replaced easily with a CG object - Params params_; - Barrier *barrier_ptr_; - PipelineState stage_; - - static constexpr int Depth = SequenceDepth; - static constexpr int Length = SequenceLength; - -public: - OrderedSequenceBarrier() = delete; - OrderedSequenceBarrier(const OrderedSequenceBarrier&) = delete; - OrderedSequenceBarrier(OrderedSequenceBarrier&&) = delete; - OrderedSequenceBarrier& operator=(const OrderedSequenceBarrier&) = delete; - OrderedSequenceBarrier& operator=(OrderedSequenceBarrier&&) = delete; - ~OrderedSequenceBarrier() = default; - - CUTLASS_DEVICE - OrderedSequenceBarrier(SharedStorage& storage, Params const& params) : - params_(params), - barrier_ptr_(&storage.barrier_[0][0]), - // Group 0 - starts with an opposite phase - stage_({0, params.group_id == 0, 0}) { - -#if (__CUDA_ARCH__ >= 1000) - int warp_idx = canonical_warp_idx_sync(); - - // Barrier FULL, EMPTY init - if (warp_idx == params.initializing_warp) { - int arv_cnt = params.group_size; - CUTLASS_ASSERT(arv_cnt > 0 && "Arrive count must be non-zero"); - constexpr int Stages = Depth * Length; - cutlass::arch::detail::initialize_barrier_array_aligned( - barrier_ptr_, arv_cnt); - } -#else - - int warp_idx = canonical_warp_idx_sync(); - int lane_predicate = cute::elect_one_sync(); - CUTLASS_ASSERT(params.group_size > 0 && "Group size must be non-zero"); - - // Barrier FULL, EMPTY init - // Init is done only by the one elected thread of the block - if (warp_idx == 0 && lane_predicate) { - for (int d = 0; d < Depth; ++d) { - for (int l = 0; l < Length; ++l) { - barrier_ptr_[d * Length + l].init(params.group_size); - } - } - } -#endif - cutlass::arch::fence_barrier_init(); - } - - // Wait on a stage to be unlocked - CUTLASS_DEVICE - void wait() { - get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); - } - - // Signal completion of Stage and move to the next stage - // (group_id) signals to (group_id+1) - CUTLASS_DEVICE - void arrive() { - int signalling_id = (params_.group_id + 1) % Length; - get_barrier_for_current_stage(signalling_id).arrive(); - ++stage_; - } - - CUTLASS_DEVICE - void advance() { - ++stage_; - } - -private: - - CUTLASS_DEVICE - Barrier& get_barrier_for_current_stage(int group_id) { - return barrier_ptr_[stage_.index() * Length + group_id]; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Synchronization call. Blocks until barriers are initialized in shared memory. -CUTLASS_DEVICE -void -pipeline_init_wait(int cluster_size) { - if (cluster_size > 1) { - cute::cluster_wait(); - } - else { - __syncthreads(); - } -} - -// Used to guarantee that the Pipeline init is visible -// to all producers and consumer threadblocks in the cluster -CUTLASS_DEVICE -void -pipeline_init_arrive_relaxed(int cluster_size) { - if (cluster_size > 1) { - cute::cluster_arrive_relaxed(); - } - else { - __syncthreads(); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // end namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pitch_linear_coord.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pitch_linear_coord.h deleted file mode 100644 index 1b782ecef78928ade707daac617b8707bf720eb6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/pitch_linear_coord.h +++ /dev/null @@ -1,181 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template defining a shape used by pitch-linear operators -template < - int Contiguous, - int Strided -> -struct PitchLinearShape { - static int const kContiguous = Contiguous; - static int const kStrided = Strided; - static int const kCount = Contiguous * Strided; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Coordinate in pitch-linear space -struct PitchLinearCoord : public Coord<2, int> { -public: - - /// Integer-valued index - using Index = int; - - /// Base type is a Coord of rank=2 - using Base = Coord<2, Index>; - - /// Long integer type - using LongIndex = typename Base::LongIndex; - -private: - - /// Rows dimension - static int const kContiguous = 0; - - /// Columns dimension - static int const kStrided = 1; - -public: - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - PitchLinearCoord() { } - - /// Constructs from Coord<2> - CUTLASS_HOST_DEVICE - PitchLinearCoord(Coord<2, Index> const &coord): Base(coord) { } - - /// Helper to construct from a row and column - CUTLASS_HOST_DEVICE - PitchLinearCoord(Index contiguous_, Index strided_): Base(make_Coord(contiguous_, strided_)) { } - - /// Helper to construct from a row and column based on LongIndex - CUTLASS_HOST_DEVICE - PitchLinearCoord(LongIndex contiguous_, LongIndex strided_) - : Base(make_Coord(Index(contiguous_), Index(strided_))) { } - - /// Returns the contiguous dimension - CUTLASS_HOST_DEVICE - Index const & contiguous() const { return this->at(kContiguous); } - - /// Returns the contiguous dimension - CUTLASS_HOST_DEVICE - Index & contiguous() { return this->at(kContiguous); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index const & strided() const { return this->at(kStrided); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index & strided() { return this->at(kStrided); } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - PitchLinearCoord operator+(Base const& b) const { - return PitchLinearCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - PitchLinearCoord operator-(Base const& b) const { - return PitchLinearCoord(Base::operator-(b)); - } - - CUTLASS_HOST_DEVICE - PitchLinearCoord operator-() const { - return PitchLinearCoord(-at(0), -at(1)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - PitchLinearCoord operator*(Base const& b) const { - return PitchLinearCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - PitchLinearCoord operator/(Base const& b) const { - return PitchLinearCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/platform/platform.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/platform/platform.h deleted file mode 100644 index 86ba43a4cc06d84d911d8b135babbad0338894ef..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/platform/platform.h +++ /dev/null @@ -1,953 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -/** - * \file - * \brief C++ features that may be otherwise unimplemented for CUDA device functions. - * - * This file has three components: - * - * (1) Macros: - * - Empty macro defines for C++ keywords not supported by the current - * version of C++. These simply allow compilation to proceed (but do - * not provide the added semantics). - * - \p noexcept - * - \p constexpr - * - \p nullptr - * - \p static_assert - * - * - Macro functions that we need in constant expressions because the - * C++ equivalents require constexpr compiler support. These are - * prefixed with \p __NV_STD_* - * - \p __NV_STD_MAX - * - \p __NV_STD_MIN - * - * (2) Re-implementations of STL functions and types: - * - C++ features that need the \p __device__ annotation. These are - * placed into the \p platform namespace. - * - \p abs - * - \p plus - * - \p less - * - \p greater - * - \p min - * - \p max - * - \p methods on std::pair (==, !=, <, <=, >, >=, and make_pair()) - * - * (3) Stop-gap implementations of unsupported STL functions and types: - * - STL functions and types defined by C++ 11/14/17/etc. that are not - * provided by the current version of C++. These are placed into the - * \p platform namespace - * - \p integral_constant - * - \p nullptr_t - * - \p true_type - * - \p false_type - * - \p bool_constant - * - \p enable_if - * - \p conditional - * - \p is_same - * - \p is_base_of - * - \p remove_const - * - \p remove_volatile - * - \p remove_cv - * - \p is_volatile - * - \p is_pointer - * - \p is_void - * - \p is_integral - * - \p is_floating_point - * - \p is_arithmetic - * - \p is_fundamental - * - \p is_trivially_copyable - * - \p alignment_of - * - \p aligned_storage - * - * The idea is that, as we drop support for older compilers, we can simply #define - * the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++ - * counterparts (or trivially find-and-replace their occurrences in code text). - */ - -//----------------------------------------------------------------------------- -// Dependencies -//----------------------------------------------------------------------------- -#include -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(type_traits) -#include CUDA_STD_HEADER(utility) -#include CUDA_STD_HEADER(cstddef) -#include CUDA_STD_HEADER(cstdint) -#include CUDA_STD_HEADER(limits) -#else -#include -#include -#include -#include -#include -#endif - -#if !defined(__CUDACC_RTC__) -//----------------------------------------------------------------------------- -// Include STL files that platform provides functionality for -//----------------------------------------------------------------------------- - -#include // Minimum/maximum operations -#include // nullptr_t -#include // Arithmetic operations -#include // For methods on std::pair -#include // float_round_style, float_denorm_style -#if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500)) -#include // For integral constants, conditional metaprogramming, and type traits -#endif - -#include - -#endif - -//----------------------------------------------------------------------------- -// OS -//----------------------------------------------------------------------------- -#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__) -#define CUTLASS_OS_WINDOWS -#endif - -#if defined(__clang__) && defined(__CUDA__) -#define CUTLASS_CLANG_CUDA 1 -#endif - -/****************************************************************************** - * Macros - ******************************************************************************/ -/// std -#if !defined(CUTLASS_STL_NAMESPACE) -#if defined(__CUDACC_RTC__) -#define CUTLASS_STL_NAMESPACE cuda::std -#else -#define CUTLASS_STL_NAMESPACE std -#endif -#endif - -/// builtin_unreachable -#if !defined(CUTLASS_GCC_UNREACHABLE) -# if defined(__GNUC__) -# define CUTLASS_GCC_UNREACHABLE __builtin_unreachable() -# else -# define CUTLASS_GCC_UNREACHABLE -# endif -#endif - -//----------------------------------------------------------------------------- -// Keywords -//----------------------------------------------------------------------------- - -/// noexcept, constexpr -#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) -#ifndef noexcept -#define noexcept -#endif -#ifndef constexpr -#define constexpr -#endif -#endif - -/// nullptr -#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1310)) -#ifndef nullptr -#define nullptr 0 -#endif -#endif - -/// static_assert -#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600)) -#ifndef static_assert -#define __platform_cat_(a, b) a##b -#define __platform_cat(a, b) __platform_cat_(a, b) -#define static_assert(__e, __m) typedef int __platform_cat(AsSeRt, __LINE__)[(__e) ? 1 : -1] -#endif -#endif - -//----------------------------------------------------------------------------- -// Functions -//----------------------------------------------------------------------------- - -/// Select maximum(a, b) -#ifndef __NV_STD_MAX -#define __NV_STD_MAX(a, b) (((b) > (a)) ? (b) : (a)) -#endif - -/// Select minimum(a, b) -#ifndef __NV_STD_MIN -#define __NV_STD_MIN(a, b) (((b) < (a)) ? (b) : (a)) -#endif - -/****************************************************************************** - * Re-implementations - ******************************************************************************/ -namespace cutlass { -namespace platform { - -//----------------------------------------------------------------------------- -// Abs operations -//----------------------------------------------------------------------------- - -#if defined(__CUDACC_RTC__) -/// std::abs -CUTLASS_HOST_DEVICE constexpr int abs(int a) { - return (a < 0) ? -a : a; -} -CUTLASS_HOST_DEVICE constexpr long long abs(long long a) { - return (a < 0) ? -a : a; -} -#else -using std::abs; -#endif - -//----------------------------------------------------------------------------- -// Minimum/maximum operations -//----------------------------------------------------------------------------- - -/// std::min -template -CUTLASS_HOST_DEVICE constexpr const T& min(const T& a, const T& b) { - return (b < a) ? b : a; -} - -/// std::max -template -CUTLASS_HOST_DEVICE constexpr const T& max(const T& a, const T& b) { - return (a < b) ? b : a; -} - -#if !defined(__CUDACC_RTC__) -//----------------------------------------------------------------------------- -// Methods on std::pair -//----------------------------------------------------------------------------- - -using std::pair; - -template -CUTLASS_HOST_DEVICE constexpr bool operator==(const pair& lhs, const pair& rhs) { - return (lhs.first == rhs.first) && (lhs.second == rhs.second); -} - -template -CUTLASS_HOST_DEVICE constexpr bool operator!=(const pair& lhs, const pair& rhs) { - return (lhs.first != rhs.first) && (lhs.second != rhs.second); -} - -template -CUTLASS_HOST_DEVICE constexpr bool operator<(const pair& lhs, const pair& rhs) { - return (lhs.first < rhs.first) ? true : (rhs.first < lhs.first) ? false - : (lhs.second < rhs.second); -} - -template -CUTLASS_HOST_DEVICE constexpr bool operator<=(const pair& lhs, const pair& rhs) { - return !(rhs < lhs); -} - -template -CUTLASS_HOST_DEVICE constexpr bool operator>(const pair& lhs, const pair& rhs) { - return (rhs < lhs); -} - -template -CUTLASS_HOST_DEVICE constexpr bool operator>=(const pair& lhs, const pair& rhs) { - return !(lhs < rhs); -} - -template -CUTLASS_HOST_DEVICE std::pair make_pair(T1 t, T2 u) { - std::pair retval; - retval.first = t; - retval.second = u; - return retval; -} -#endif - -} // namespace platform - -/****************************************************************************** - * Implementations of C++ 11/14/17/... STL features - ******************************************************************************/ - -namespace platform { - -//----------------------------------------------------------------------------- -// Integral constant helper types -//----------------------------------------------------------------------------- - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) - -#else - -using std::pair; - -#endif - -using CUTLASS_STL_NAMESPACE::integral_constant; -using CUTLASS_STL_NAMESPACE::bool_constant; -using CUTLASS_STL_NAMESPACE::true_type; -using CUTLASS_STL_NAMESPACE::false_type; - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700)) - -/// std::nullptr_t -struct nullptr_t {}; - -#else - -using std::nullptr_t; - -#endif - -//----------------------------------------------------------------------------- -// Conditional metaprogramming -//----------------------------------------------------------------------------- - -using CUTLASS_STL_NAMESPACE::conditional; -using CUTLASS_STL_NAMESPACE::conditional_t; -using CUTLASS_STL_NAMESPACE::enable_if; -using CUTLASS_STL_NAMESPACE::enable_if_t; -using CUTLASS_STL_NAMESPACE::void_t; - -//----------------------------------------------------------------------------- -// Const/volatility specifiers -//----------------------------------------------------------------------------- - -using CUTLASS_STL_NAMESPACE::remove_const; -using CUTLASS_STL_NAMESPACE::remove_const_t; -using CUTLASS_STL_NAMESPACE::remove_cv; -using CUTLASS_STL_NAMESPACE::remove_cv_t; -using CUTLASS_STL_NAMESPACE::remove_reference; -using CUTLASS_STL_NAMESPACE::remove_reference_t; -using CUTLASS_STL_NAMESPACE::remove_volatile; -using CUTLASS_STL_NAMESPACE::remove_volatile_t; - -// remove_cvref and remove_cvref_t are C++20 features, -// but CUTLASS finds them useful enough to back-port. -#if defined(__cpp_lib_remove_cvref) - -using CUTLASS_STL_NAMESPACE::remove_cvref; -using CUTLASS_STL_NAMESPACE::remove_cvref_t; - -#else - -template -struct remove_cvref { - using type = remove_cv_t>; -}; - -template -using remove_cvref_t = typename remove_cvref::type; - -#endif - -//----------------------------------------------------------------------------- -// Type relationships -//----------------------------------------------------------------------------- - -using CUTLASS_STL_NAMESPACE::is_same; -using CUTLASS_STL_NAMESPACE::is_same_v; - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) - -/// Helper for std::is_base_of -template -struct is_base_of_helper { - typedef char (&yes)[1]; - typedef char (&no)[2]; - - template - struct dummy { - CUTLASS_HOST_DEVICE operator B*() const; - CUTLASS_HOST_DEVICE operator D*(); - }; - - template - CUTLASS_HOST_DEVICE static yes check(DerivedT*, T); - - CUTLASS_HOST_DEVICE static no check(BaseT*, int); - - static const bool value = sizeof(check(dummy(), int())) == sizeof(yes); -}; - -/// std::is_base_of -template -struct is_base_of - : integral_constant::type, - typename remove_cv::type>::value) || - (is_same::type, - typename remove_cv::type>::value)> {}; - -#else - -using std::is_base_of; - -#endif - -//----------------------------------------------------------------------------- -// Type properties -//----------------------------------------------------------------------------- - -using CUTLASS_STL_NAMESPACE::is_arithmetic; -using CUTLASS_STL_NAMESPACE::is_arithmetic_v; -using CUTLASS_STL_NAMESPACE::is_void; -using CUTLASS_STL_NAMESPACE::is_void_v; - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) - -/// std::is_volatile -template -struct is_volatile : false_type {}; -template -struct is_volatile : true_type {}; - -/// Helper for std::is_pointer (false specialization) -template -struct is_pointer_helper : false_type {}; - -/// Helper for std::is_pointer (true specialization) -template -struct is_pointer_helper : true_type {}; - -/// std::is_pointer -template -struct is_pointer : is_pointer_helper::type> {}; - -/// std::is_integral -template -struct is_integral : false_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template <> -struct is_integral : true_type {}; -template -struct is_integral : is_integral {}; -template -struct is_integral : is_integral {}; -template -struct is_integral : is_integral {}; - -/// std::is_floating_point -template -struct is_floating_point - : integral_constant::type>::value || - is_same::type>::value)> {}; - -/// std::is_fundamental -template -struct is_fundamental - : integral_constant::value || is_void::value || - is_same::type>::value)> {}; - -#else - -using std::is_volatile; -using std::is_pointer; -using std::is_integral; -using std::is_floating_point; -using std::is_fundamental; - -#endif - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) || \ - (defined(__GNUG__) && (__GNUC__ < 5)) - -/** - * std::is_trivially_copyable - * - * This implementation only evaluates true if T is fundamental or pointer - * - * Without help from partial template specializations provided by the user for - * a specific class or struct, this trait will never report that the specified - * class or struct is trivially-copyable ; this is always safe, - * if possibly sub-optimal. - */ -template -struct is_trivially_copyable - : integral_constant::value || is_pointer::value)> {}; - -#else - -using std::is_trivially_copyable; - -#endif - -#if (CUTLASS_CXX17_OR_LATER) - -/// std::is_unsigned_v -using CUTLASS_STL_NAMESPACE::is_integral_v; -/// std::is_unsigned_v -using CUTLASS_STL_NAMESPACE::is_unsigned_v; - -#endif - -//----------------------------------------------------------------------------- -// -//----------------------------------------------------------------------------- - -using CUTLASS_STL_NAMESPACE::declval; - -//----------------------------------------------------------------------------- -// bit_cast -//----------------------------------------------------------------------------- - -template< class To, class From > -constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept; - -template -constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept -{ - static_assert(sizeof(To) == sizeof(From), "sizes must match"); - return reinterpret_cast(src); -} - -//----------------------------------------------------------------------------- -// Convertable -//----------------------------------------------------------------------------- -using CUTLASS_STL_NAMESPACE::is_convertible; -using CUTLASS_STL_NAMESPACE::is_convertible_v; - -//----------------------------------------------------------------------------- -// Alignment and layout utilities -//----------------------------------------------------------------------------- - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) - -/// std::alignment_of -template -struct alignment_of { - struct pad { - value_t val; - char byte; - }; - - enum { value = sizeof(pad) - sizeof(value_t) }; -}; - -#else - -template -struct alignment_of : std::alignment_of {}; - -#endif - -/* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */ -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; - -#if !defined(CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED) -#define CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED (__CUDACC_VER_MAJOR__ >= 13) -#endif - -#if (CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED) -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 32 }; -}; -template <> -struct alignment_of { - enum { value = 32 }; -}; -template <> -struct alignment_of { - enum { value = 32 }; -}; -template <> -struct alignment_of { - enum { value = 32 }; -}; -template <> -struct alignment_of { - enum { value = 32 }; -}; -#else -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; - -#endif - -// Specializations for volatile/const qualified types -template -struct alignment_of : alignment_of {}; -template -struct alignment_of : alignment_of {}; -template -struct alignment_of : alignment_of {}; - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) - -template -struct aligned_chunk; -template <> -struct __align__(1) aligned_chunk<1> { - uint8_t buff; -}; -template <> -struct __align__(2) aligned_chunk<2> { - uint16_t buff; -}; -template <> -struct __align__(4) aligned_chunk<4> { - uint32_t buff; -}; -template <> -struct __align__(8) aligned_chunk<8> { - uint32_t buff[2]; -}; -template <> -struct __align__(16) aligned_chunk<16> { - uint32_t buff[4]; -}; -template <> -struct __align__(32) aligned_chunk<32> { - uint32_t buff[8]; -}; -template <> -struct __align__(64) aligned_chunk<64> { - uint32_t buff[16]; -}; -template <> -struct __align__(128) aligned_chunk<128> { - uint32_t buff[32]; -}; -template <> -struct __align__(256) aligned_chunk<256> { - uint32_t buff[64]; -}; -template <> -struct __align__(512) aligned_chunk<512> { - uint32_t buff[128]; -}; -template <> -struct __align__(1024) aligned_chunk<1024> { - uint32_t buff[256]; -}; -template <> -struct __align__(2048) aligned_chunk<2048> { - uint32_t buff[512]; -}; -template <> -struct __align__(4096) aligned_chunk<4096> { - uint32_t buff[1024]; -}; - -/// std::aligned_storage -template -struct aligned_storage { - typedef aligned_chunk type[Len / sizeof(aligned_chunk)]; -}; - -#else - -using std::aligned_storage; - -#endif - -#if !defined(__CUDACC_RTC__) -/// Default deleter -template -struct default_delete { - void operator()(T* ptr) const { delete ptr; } -}; - -/// Partial specialization for deleting array types -template -struct default_delete { - void operator()(T* ptr) const { delete[] ptr; } -}; - -/// std::unique_ptr -template > -class unique_ptr { - public: - typedef T* pointer; - typedef T element_type; - typedef Deleter deleter_type; - - private: - /// Pointer to memory - pointer _ptr; - - /// Deleter - deleter_type _deleter; - - public: - unique_ptr() : _ptr(nullptr) {} - unique_ptr(pointer p) : _ptr(p) {} - - ~unique_ptr() { - if (_ptr) { - _deleter(_ptr); - } - } - /// Returns a pointer to the managed object or nullptr if no object is owned. - pointer get() const noexcept { return _ptr; } - - /// Releases ownership of the managed object, if any - pointer release() noexcept { - pointer p(_ptr); - _ptr = nullptr; - return p; - } - - /// Replaces the managed object, deleting the old object. - void reset(pointer p = pointer()) noexcept { - pointer old_ptr = _ptr; - _ptr = p; - if (old_ptr != nullptr) { - get_deleter()(old_ptr); - } - } - - /// Swaps the managed objects with *this and another unique_ptr - void swap(unique_ptr& other) noexcept { std::swap(_ptr, other._ptr); } - - /// Returns the deleter object - Deleter& get_deleter() noexcept { return _deleter; } - - /// Returns the deleter object - Deleter const& get_deleter() const noexcept { return _deleter; } - - /// Checks whether an object is owned - operator bool() const noexcept { return _ptr != nullptr; } - - /// Dereferences the unique_ptr - T& operator*() const { return *_ptr; } - - /// Returns a pointer to the managed object - pointer operator->() const noexcept { return _ptr; } - - /// Array access to managed object - T& operator[](size_t i) const { return _ptr[i]; } -}; - -/// Specializes the swap algorithm -template -void swap(unique_ptr& lhs, unique_ptr& rhs) noexcept { - lhs.swap(rhs); -} -#endif - -/// std::numeric_limits -template -struct numeric_limits; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr int32_t lowest() noexcept { return -2147483647 - 1;} - CUTLASS_HOST_DEVICE - static constexpr int32_t max() noexcept { return 2147483647;} - static constexpr bool is_integer = true; - static constexpr bool has_infinity = false; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr int16_t lowest() noexcept { return -32768;} - CUTLASS_HOST_DEVICE - static constexpr int16_t max() noexcept { return 32767;} - static constexpr bool is_integer = true; - static constexpr bool has_infinity = false; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr int8_t lowest() noexcept { return -128;} - CUTLASS_HOST_DEVICE - static constexpr int8_t max() noexcept { return 127;} - static constexpr bool is_integer = true; - static constexpr bool has_infinity = false; -}; - - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr uint32_t lowest() noexcept { return 0;} - CUTLASS_HOST_DEVICE - static constexpr uint32_t max() noexcept { return 4294967295U;} - static constexpr bool is_integer = true; - static constexpr bool has_infinity = false; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr uint16_t lowest() noexcept { return 0;} - CUTLASS_HOST_DEVICE - static constexpr uint16_t max() noexcept { return 65535U;} - static constexpr bool is_integer = true; - static constexpr bool has_infinity = false; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr uint8_t lowest() noexcept { return 0;} - CUTLASS_HOST_DEVICE - static constexpr uint8_t max() noexcept { return 255U;} - static constexpr bool is_integer = true; - static constexpr bool has_infinity = false; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE - static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} - CUTLASS_HOST_DEVICE - static constexpr float max() noexcept { return bit_cast(0x7f7fffff);} - static constexpr bool is_integer = false; - static constexpr bool has_infinity = true; -}; - -/// Returns a value that curries the `std::maximum()` function into the identity -/// function. No value will compare < than this value. -template -constexpr T identity_for_maximum() { - if constexpr (numeric_limits::has_infinity) { - return -numeric_limits::infinity(); - } else { - return numeric_limits::lowest(); - } -} - -/// Returns a value that curries the `std::minimum()` function into the identity -/// function. No value will compare > than this value. -template -constexpr T identity_for_minimum() { - if constexpr (numeric_limits::has_infinity) { - return numeric_limits::infinity(); - } else { - return numeric_limits::max(); - } -} - -/// std::float_round_style -using CUTLASS_STL_NAMESPACE::float_round_style; -using CUTLASS_STL_NAMESPACE::round_indeterminate; -using CUTLASS_STL_NAMESPACE::round_toward_zero; -using CUTLASS_STL_NAMESPACE::round_to_nearest; -using CUTLASS_STL_NAMESPACE::round_toward_infinity; -using CUTLASS_STL_NAMESPACE::round_toward_neg_infinity; - -/// std::float_denorm_style -using CUTLASS_STL_NAMESPACE::float_denorm_style; -using CUTLASS_STL_NAMESPACE::denorm_indeterminate; -using CUTLASS_STL_NAMESPACE::denorm_absent; -using CUTLASS_STL_NAMESPACE::denorm_present; - -} // namespace platform -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/predicate_vector.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/predicate_vector.h deleted file mode 100644 index c3867c570340fd41480c7806456d269eed0b1189..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/predicate_vector.h +++ /dev/null @@ -1,545 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines container classes and iterators for managing a statically sized vector - of boolean predicates. -*/ -#pragma once -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#endif - -#include CUDA_STD_HEADER(cassert) - -#include "cutlass/platform/platform.h" - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/*!@defgroup predicate_vector_concept Predicate Vector Concept -@{ - -Implementations of \ref predicate_vector_concept contain an ordered set of boolean predicates which -may be used as conditionals in other device-side operations. Both random access and iterators -offering sequential access are provided. - -@par Predicate Vector - A \ref predicate_vector_concept satisfies the following expressions - - at(int idx) - returns the value of the indexed predicate - - set(int idx, bool value) - sets the value of the indexed predicate - - begin() - returns a \ref predicate_iterator_concept pointing to the first predicate - -@} -*/ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/*!@defgroup predicate_iterator_concept Predicate Iterator Concept -@{ - -Implementations of \ref predicate_iterator_concept enables accessing and traversing elements of a -bit vector. - -@par Const Predicate Iterator - A const \ref predicate_iterator_concept satisfies the following expressions - - ++it increments the iterator to the next predicate - - *it returns the value of the currently pointed-to predicate - -@par Mutable Predicate Iterator - A \ref predicate_iterator_concept that is non-const also satisfies the following expressions - - it.set(bool value) sets the value of the currently pointed-to predicate - -@} -*/ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/*!@defgroup predicate_tile_adapter Predicate Tile Adapter Concept -@{ - -Implementations of \ref predicate_tile_adapter provide a mapping between a the elements of a \ref -tile_traits_concept and a \ref predicate_vector_concept. - -@par Predicate Tile Adapter - A \ref predicate_tile_adapter satisfies the following expressions - - at(int d, int h, int w, int c) - returns the value of a predicate corresponding to the - access (d, h, w, c) within the tile. - -@} -*/ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Statically sized array of bits implementing @concept{predicate_vector_concept}. -template < - /// Number of predicates contained in predicate vector - int kPredicates_, - /// Number of predicates contained in each byte of internal storage - int kPredicatesPerByte_ = 4, - /// Location of first predicate within byte of internal storage - int kPredicateStart_ = 0> -struct PredicateVector { - /// Number of bits stored by the PredicateVector - static constexpr int kPredicates = kPredicates_; - - /// Number of bits stored within each byte of the predicate bit vector - static constexpr int kPredicatesPerByte = kPredicatesPerByte_; - - /// First bit within each byte containing predicates - static constexpr int kPredicateStart = kPredicateStart_; - - // Make sure no one tries to put more than 8 bits in a byte :) - static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte"); - // Make sure the "offsetted" bits fit in one byte. - static_assert(kPredicateStart + kPredicatesPerByte <= 8, - "The offsetted predicates must fit within an actual byte."); - - /// Storage type of individual elements - typedef uint32_t Storage; - - /// Number of bytes needed - static constexpr int kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; - - /// Number of storage elements needed - static constexpr int kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); - - /// The byte mask corresponding to predicates - static constexpr Storage kByteMask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart); - - private: - // - // Data members - // - - /// Words of bit vector - Storage storageData[kWordCount]; - - // - // Methods - // - - /// Computes the word and bit corresponding to a logical predicate index - CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const { - CUTLASS_ASSERT(idx < kPredicates); - - int byte = (idx / kPredicatesPerByte); - int bit_offset = (idx % kPredicatesPerByte); - - word = byte / sizeof(Storage); - int byte_offset = (byte % sizeof(Storage)); - - bit = byte_offset * 8 + bit_offset + kPredicateStart; - } - - /// Returns word mask. - CUTLASS_HOST_DEVICE static constexpr bool computeWordMask() { - Storage mask(0); - CUTLASS_PRAGMA_UNROLL - for (size_t byte = 0; byte < sizeof(Storage); ++byte) { - mask |= (kByteMask << (byte * 8)); - } - return mask; - } - - /// Returns mask of last word. - CUTLASS_HOST_DEVICE static constexpr bool computeLastWordMask() { - Storage mask(0); - CUTLASS_PRAGMA_UNROLL - for (int byte = 0; byte < kBytes % sizeof(Storage); ++byte) { - mask |= (kByteMask << (byte * 8)); - } - return mask; - } - - /// Accesses a given word with optional assertions - CUTLASS_HOST_DEVICE Storage &storage(int word) { - CUTLASS_ASSERT(word < kWordCount); - return storageData[word]; - } - - /// Accesses a given word with optional assertions - CUTLASS_HOST_DEVICE Storage const &storage(int word) const { - CUTLASS_ASSERT(word < kWordCount); - return storageData[word]; - } - - public: - // - // Iterator - // - - /** - * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential - * read and write access to predicates. - * @concept{predicate_iterator_concept} - */ - class Iterator { - /// Reference to PredicateVector instance - PredicateVector &vec_; - - /// Index into PredicateVector - int bit_; - - public: - /// Copy constructor - CUTLASS_HOST_DEVICE - Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {} - - /// Constructs an iterator from a PredicateVector - CUTLASS_HOST_DEVICE - Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {} - - /// Pre-increment - CUTLASS_HOST_DEVICE - Iterator &operator++() { - ++bit_; - return *this; - } - - /// Increment - CUTLASS_HOST_DEVICE - Iterator &operator+=(int offset) { - bit_ += offset; - return *this; - } - - /// Pre-decrement - CUTLASS_HOST_DEVICE - Iterator &operator--() { - --bit_; - return *this; - } - - /// Decrement - CUTLASS_HOST_DEVICE - Iterator &operator-=(int offset) { - bit_ -= offset; - return *this; - } - - /// Post-increment - CUTLASS_HOST_DEVICE - Iterator operator++(int) { - Iterator ret(*this); - ret.bit_++; - return ret; - } - - /// Post-decrement - CUTLASS_HOST_DEVICE - Iterator operator--(int) { - Iterator ret(*this); - ret.bit_--; - return ret; - } - - /// Iterator advances by some amount - CUTLASS_HOST_DEVICE - Iterator operator+(int offset) { - Iterator ret(*this); - ret.bit_ += offset; - return ret; - } - - /// Iterator recedes by some amount - CUTLASS_HOST_DEVICE - Iterator operator-(int offset) { - ConstIterator ret(*this); - ret.bit_ -= offset; - return ret; - } - - /// Returns true if iterators point to the same bit - CUTLASS_HOST_DEVICE - bool operator==(Iterator const &it) const { return bit_ == it.bit_; } - - /// Returns false if iterators point to the same bit - CUTLASS_HOST_DEVICE - bool operator!=(Iterator const &it) const { return bit_ != it.bit_; } - - /// Gets the bit at the pointed to location - CUTLASS_HOST_DEVICE - bool get() { return vec_.at(bit_); } - - /// Gets the bit at the pointed to location - CUTLASS_HOST_DEVICE - bool at() const { return vec_.at(bit_); } - - /// Dereferences iterator - CUTLASS_HOST_DEVICE - bool operator*() const { return at(); } - - /// Sets the bit at the pointed to location - CUTLASS_HOST_DEVICE - void set(bool value = true) { vec_.set(bit_, value); } - }; - - /** - * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential - * read and write access to predicates. - * @concept{predicate_iterator_concept} - */ - class ConstIterator { - /// Reference to PredicateVector instance - PredicateVector const &vec_; - - /// Index into PredicateVector - int bit_; - - public: - /// Copy constructor - CUTLASS_HOST_DEVICE - ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {} - - /// Constructs an iterator from a PredicateVector - CUTLASS_HOST_DEVICE - ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {} - - /// Pre-increment - CUTLASS_HOST_DEVICE - ConstIterator &operator++() { - ++bit_; - return *this; - } - - /// Increment - CUTLASS_HOST_DEVICE - ConstIterator &operator+=(int offset) { - bit_ += offset; - return *this; - } - - /// Pre-decrement - CUTLASS_HOST_DEVICE - ConstIterator &operator--() { - --bit_; - return *this; - } - - /// Decrement - CUTLASS_HOST_DEVICE - ConstIterator &operator-=(int offset) { - bit_ -= offset; - return *this; - } - - /// Post-increment - CUTLASS_HOST_DEVICE - ConstIterator operator++(int) { - ConstIterator ret(*this); - ret.bit_++; - return ret; - } - - /// Post-decrement - CUTLASS_HOST_DEVICE - ConstIterator operator--(int) { - ConstIterator ret(*this); - ret.bit_--; - return ret; - } - - /// Iterator advances by some amount - CUTLASS_HOST_DEVICE - ConstIterator operator+(int offset) { - ConstIterator ret(*this); - ret.bit_ += offset; - return ret; - } - - /// Iterator recedes by some amount - CUTLASS_HOST_DEVICE - ConstIterator operator-(int offset) { - ConstIterator ret(*this); - ret.bit_ -= offset; - return ret; - } - - /// Returns true if iterators point to the same bit - CUTLASS_HOST_DEVICE - bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; } - - /// Returns false if iterators point to the same bit - CUTLASS_HOST_DEVICE - bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; } - - /// Gets the bit at the pointed to location - CUTLASS_HOST_DEVICE - bool get() { return vec_.at(bit_); } - - /// Gets the bit at the pointed to location - CUTLASS_HOST_DEVICE - bool at() const { return vec_.at(bit_); } - - /// Dereferences iterator - CUTLASS_HOST_DEVICE - bool operator*() const { return at(); } - }; - - /// Iterator that always returns true - struct TrivialIterator { - /// Constructor - CUTLASS_HOST_DEVICE - TrivialIterator() {} - - /// Copy constructor - CUTLASS_HOST_DEVICE - TrivialIterator(Iterator const &it) {} - - /// Constructs an iterator from a PredicateVector - CUTLASS_HOST_DEVICE - TrivialIterator(PredicateVector const &_vec) {} - - /// Pre-increment - CUTLASS_HOST_DEVICE - TrivialIterator &operator++() { return *this; } - - /// Post-increment - CUTLASS_HOST_DEVICE - TrivialIterator operator++(int) { return *this; } - - /// Dereferences iterator - CUTLASS_HOST_DEVICE - bool operator*() const { return true; } - }; - - public: - // - // Methods - // - - /// Initialize the predicate vector - CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); } - - /// Fills all predicates with a given value - CUTLASS_HOST_DEVICE void fill(bool value = true) { - Storage item = (value ? ~Storage(0) : Storage(0)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kWordCount; ++i) { - storage(i) = item; - } - } - - /// Clears all predicates - CUTLASS_HOST_DEVICE void clear() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kWordCount; ++i) { - storage(i) = 0; - } - } - - /// Sets all predicates to true - CUTLASS_HOST_DEVICE void enable() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kWordCount; ++i) { - storage(i) = ~Storage(0); - } - } - - /// Accesses a bit within the predicate vector. - CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); } - - /// Accesses a bit within the predicate vector. - CUTLASS_HOST_DEVICE bool at(int idx) const { - int bit, word; - computeStorageOffset(word, bit, idx); - - return ((storage(word) >> bit) & 1); - } - - /// Set a bit within the predicate vector. - CUTLASS_HOST_DEVICE void set(int idx, bool value = true) { - int bit, word; - computeStorageOffset(word, bit, idx); - - Storage disable_mask = (~(Storage(1) << bit)); - Storage enable_mask = (Storage(value) << bit); - - storage(word) = ((storage(word) & disable_mask) | enable_mask); - } - - /// Computes the intersection of two identical predicate vectors. - CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kWordCount; ++i) { - storage(i) = (storage(i) & predicates.storage(i)); - } - return *this; - } - - /// Computes the union of two identical predicate vectors. - CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kWordCount; ++i) { - storage(i) = (storage(i) | predicates.storage(i)); - } - return *this; - } - - /// Returns true if entire predicate array is zero. - CUTLASS_HOST_DEVICE bool is_zero() const { - constexpr Storage mask = computeWordMask(); - Storage result = 0; - CUTLASS_PRAGMA_UNROLL - for (int word = 0; word < kWordCount - 1; ++word) { - result |= (storage(word) & mask); - } - constexpr Storage last_word_mask = computeLastWordMask(); - result |= (storage(kWordCount - 1) & last_word_mask); - - return result == 0; - } - - /// Returns an iterator to the start of the bit vector - CUTLASS_DEVICE - Iterator begin() { return Iterator(*this); } - - /// Returns an iterator - CUTLASS_DEVICE - Iterator end() { return Iterator(*this, kPredicates); } - - /// Returns a ConstIterator - CUTLASS_DEVICE - ConstIterator const_begin() const { return ConstIterator(*this); } - - /// Returns a ConstIterator - CUTLASS_DEVICE - ConstIterator const_end() const { return ConstIterator(*this, kPredicates); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/quaternion.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/quaternion.h deleted file mode 100644 index 48ca3628777d5eeca1582ef2703ee01923903f26..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/quaternion.h +++ /dev/null @@ -1,752 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a densely packed quaternion object intended for storing data in registers and - executing quaternion operations within a CUDA or host thread. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/functional.h" -#include "cutlass/array.h" -#include "cutlass/real.h" -#include "cutlass/coord.h" -#include "cutlass/matrix.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/vector.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Quaternion: xi + yj + zk + w -template < - typename Element_ = float ///< element type -> -class Quaternion : public Array { -public: - - /// Logical rank of tensor index space - static int const kRank = 1; - - /// Number of elements - static int const kExtent = 4; - - /// Base class is a four-element array - using Base = Array; - - /// Element type - using Element = typename Base::Element; - - /// Reference type to an element - using Reference = typename Base::reference; - - /// Index type - using Index = int; - - /// Quaternion storage - imaginary part - static int const kX = 0; - - /// Quaternion storage - imaginary part - static int const kY = 1; - - /// Quaternion storage - imaginary part - static int const kZ = 2; - - /// Quaternion storage - real part - static int const kW = 3; - -public: - - // - // Methods - // - - /// Constructs a quaternion q = 0 - CUTLASS_HOST_DEVICE - Quaternion() { - Base::at(kX) = Element(); - Base::at(kY) = Element(); - Base::at(kZ) = Element(); - Base::at(kW) = Element(); - } - - /// Constructs a quaternion q = w + 0*i + 0*j + 0*k - CUTLASS_HOST_DEVICE - Quaternion( - Element w_ - ) { - Base::at(kX) = Element(); - Base::at(kY) = Element(); - Base::at(kZ) = Element(); - Base::at(kW) = w_; - } - - /// Constructs a quaternion q = w + x*i + y*j + z*k - CUTLASS_HOST_DEVICE - Quaternion( - Element x_, - Element y_, - Element z_, - Element w_ - ) { - Base::at(kX) = x_; - Base::at(kY) = y_; - Base::at(kZ) = z_; - Base::at(kW) = w_; - } - - /// Constructs a quaternion from a vector representing the imaginary part and a real number - CUTLASS_HOST_DEVICE - Quaternion( - Matrix3x1 const &imag_, - Element w_ = Element() - ) { - Base::at(kX) = imag_[0]; - Base::at(kY) = imag_[1]; - Base::at(kZ) = imag_[2]; - Base::at(kW) = w_; - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference at(Index idx) const { - return Base::at(idx); - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference at(Index idx) { - return Base::at(idx); - } - - /// Accesses the x element of the imaginary part of the quaternion - CUTLASS_HOST_DEVICE - Element x() const { - return Base::at(kX); - } - - /// Accesses the x element of the imaginary part of the quaternion - CUTLASS_HOST_DEVICE - Reference x() { - return Base::at(kX); - } - - /// Accesses the y element of the imaginary part of the quaternion - CUTLASS_HOST_DEVICE - Element y() const { - return Base::at(kY); - } - - /// Accesses the y element of the imaginary part of the quaternion - CUTLASS_HOST_DEVICE - Reference y() { - return Base::at(kY); - } - - /// Accesses the z element of the imaginary part of the quaternion - CUTLASS_HOST_DEVICE - Element z() const { - return Base::at(kZ); - } - - /// Accesses the z element of the imaginary part of the quaternion - CUTLASS_HOST_DEVICE - Reference z() { - return Base::at(kZ); - } - - /// Accesses the real part of the quaternion - CUTLASS_HOST_DEVICE - Element w() const { - return Base::at(kW); - } - - /// Accesses the real part of the quaternion - CUTLASS_HOST_DEVICE - Reference w() { - return Base::at(kW); - } - - /// Returns the pure imaginary part of the quaternion as a 3-vector - CUTLASS_HOST_DEVICE - Matrix3x1 pure() const { - return Matrix3x1(x(), y(), z()); - } - - /// Returns a quaternion representation of a spatial rotation given a unit-length axis and - /// a rotation in radians. - CUTLASS_HOST_DEVICE - static Quaternion rotation( - Matrix3x1 const &axis_unit, ///< axis of rotation (assumed to be unit length) - Element theta) { ///< angular rotation in radians - - Element s = fast_sin(theta / Element(2)); - - return Quaternion( - s * axis_unit[0], - s * axis_unit[1], - s * axis_unit[2], - fast_cos(theta / Element(2)) - ); - } - - /// Returns a quaternion representation of a spatial rotation represented as a - /// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians - CUTLASS_HOST_DEVICE - static Quaternion rotation( - Element r_x, - Element r_y, - Element r_z, - Element theta) { ///< angular rotation in radians - - return rotation({r_x, r_y, r_z}, theta); - } - - /// Geometric rotation of a 3-element vector - CUTLASS_HOST_DEVICE - Matrix3x1 rotate(Matrix3x1 const &rhs) const { - return (*this * Quaternion(rhs, 0) * reciprocal(*this)).pure(); - } - - /// Inverse rotation operation - CUTLASS_HOST_DEVICE - Matrix3x1 rotate_inv(Matrix3x1 const &rhs) const { - return (reciprocal(*this) * Quaternion(rhs, 0) * *this).pure(); - } - - /// Rotates a 3-vector assuming this is a unit quaternion (a spinor) - CUTLASS_HOST_DEVICE - Matrix3x1 spinor(Matrix3x1 const &rhs) const { - return (*this * Quaternion(rhs, 0) * conj(*this)).pure(); - } - - /// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor) - CUTLASS_HOST_DEVICE - Matrix3x1 spinor_inv(Matrix3x1 const &rhs) const { - return (conj(*this) * Quaternion(rhs, 0) * *this).pure(); - } - - /// In-place addition - template - CUTLASS_HOST_DEVICE - Quaternion &operator+=(Quaternion const &rhs) { - *this = (*this + rhs); - return *this; - } - - /// In-place subtraction - template - CUTLASS_HOST_DEVICE - Quaternion &operator-=(Quaternion const &rhs) { - *this = (*this - rhs); - return *this; - } - - /// In-place multiplication - template - CUTLASS_HOST_DEVICE - Quaternion &operator*=(Quaternion const &rhs) { - *this = (*this * rhs); - return *this; - } - - /// Scalar multiplication - template - CUTLASS_HOST_DEVICE - Quaternion &operator*=(Element s) { - *this = (*this * s); - return *this; - } - - /// In-place Division - template - CUTLASS_HOST_DEVICE - Quaternion &operator/=(Quaternion const &rhs) { - *this = (*this / rhs); - return *this; - } - - /// In-place Division - template - CUTLASS_HOST_DEVICE - Quaternion &operator/=(Element s) { - *this = (*this / s); - return *this; - } - - /// Computes a 3x3 rotation matrix (row-major representation) - CUTLASS_HOST_DEVICE - Matrix3x3 as_rotation_matrix_3x3() const { - Matrix3x3 m( - w() * w() + x() * x() - y() * y() - z() * z(), - 2 * x() * y() - 2 * w() * z(), - 2 * x() * z() + 2 * w() * y(), - - 2 * x() * y() + 2 * w() * z(), - w() * w() - x() * x() + y() * y() - z() * z(), - 2 * y() * z() - 2 * w() * x(), - - 2 * x() * z() - 2 * w() * y(), - 2 * y() * z() + 2 * w() * x(), - w() * w() - x() * x() - y() * y() + z() * z() - ); - return m; - } - - /// Computes a 4x4 rotation matrix (row-major representation) - CUTLASS_HOST_DEVICE - Matrix4x4 as_rotation_matrix_4x4() const { - Matrix4x4 m = Matrix4x4::identity(); - m.set_slice_3x3(as_rotation_matrix_3x3()); - return m; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Constructs a quaternion that is non-zero only in its real element. -template -CUTLASS_HOST_DEVICE -Quaternion make_Quaternion( - Element w) { ///< real part - - return Quaternion(w); -} - -/// Constructs a quaternion from a vector and real -template -CUTLASS_HOST_DEVICE -Quaternion make_Quaternion( - Matrix3x1 const &imag, ///< imaginary party as a vector - Element w) { ///< real part - - return Quaternion(imag, w); -} - -/// Constructs a quaternion from a unit-length rotation axis and a rotation -/// angle in radians -template -CUTLASS_HOST_DEVICE -Quaternion make_QuaternionRotation( - Matrix3x1 const &axis_unit, ///< rotation axis (unit-length) - Element w) { ///< rotation angle in radians - - return Quaternion::rotation(axis_unit, w); -} - -/// Constructs a quaternion q = xi + yj + zk + w -template -CUTLASS_HOST_DEVICE -Quaternion make_Quaternion(Element x, Element y, Element z, Element w) { - return Quaternion(x, y, z, w); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns the real part of the quaternion number -template -CUTLASS_HOST_DEVICE -Element const &real(Quaternion const &q) { - return q.w(); -} - -/// Returns the real part of the quaternion number -template -CUTLASS_HOST_DEVICE -Element &real(Quaternion &q) { - return q.w(); -} - -/// Returns the magnitude of the quaternion number -template -CUTLASS_HOST_DEVICE -Element abs(Quaternion const &q) { - return fast_sqrt(norm(q)); -} - -/// Quaternion conjugate -template -CUTLASS_HOST_DEVICE -Quaternion conj(Quaternion const &q) { - return make_Quaternion( - -q.x(), - -q.y(), - -q.z(), - q.w() - ); -} - -/// Computes the squared magnitude of the quaternion -template -CUTLASS_HOST_DEVICE -Element norm(Quaternion const &q) { - return q.x() * q.x() + q.y() * q.y() + q.z() * q.z() + q.w() * q.w(); -} - -/// Quaternion reciprocal -template -CUTLASS_HOST_DEVICE -Quaternion reciprocal(Quaternion const &q) { - - Element nsq = norm(q); - - return make_Quaternion( - -q.x() / nsq, - -q.y() / nsq, - -q.z() / nsq, - q.w() / nsq - ); -} - -/// Returns a unit-length quaternion -template -CUTLASS_HOST_DEVICE -Quaternion unit(Quaternion const &q) { - - Element rcp_mag = Element(1) / abs(q); - - return make_Quaternion( - q.x() * rcp_mag, - q.y() * rcp_mag, - q.z() * rcp_mag, - q.w() * rcp_mag - ); -} - -/// Quaternion exponential -template -CUTLASS_HOST_DEVICE -Quaternion exp(Quaternion const &q) { - - Element exp_ = fast_exp(q.w()); - Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); - Element sin_norm = fast_sin(imag_norm); - - return make_Quaternion( - exp_ * q.x() * sin_norm / imag_norm, - exp_ * q.y() * sin_norm / imag_norm, - exp_ * q.z() * sin_norm / imag_norm, - exp_ * fast_cos(imag_norm) - ); -} - -/// Quaternion natural logarithm -template -CUTLASS_HOST_DEVICE -Quaternion log(Quaternion const &q) { - - Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); - Element s = fast_acos(q.w() / abs(q)) / v; - - return make_Quaternion( - q.x() * s, - q.y() * s, - q.z() * s, - fast_log(q.w()) - ); -} - -/// Gets the rotation angle from a unit-length quaternion -template -CUTLASS_HOST_DEVICE -Element get_rotation_angle(Quaternion const &q_unit) { - return fast_acos(q_unit.w()) * Element(2); -} - -/// Gets the rotation axis from a unit-length quaternion -template -CUTLASS_HOST_DEVICE -Matrix3x1 get_rotation_axis(Quaternion const &q_unit) { - return q_unit.pure().unit(); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Equality operator -template -CUTLASS_HOST_DEVICE -bool operator==(Quaternion const &lhs, Quaternion const &rhs) { - return lhs.x() == rhs.x() && - lhs.y() == rhs.y() && - lhs.z() == rhs.z() && - lhs.w() == rhs.w(); -} - -/// Inequality operator -template -CUTLASS_HOST_DEVICE -bool operator!=(Quaternion const &lhs, Quaternion const &rhs) { - return !(lhs == rhs); -} - -/// Quaternion scalar multiplication -template -CUTLASS_HOST_DEVICE -Quaternion operator*(Quaternion q, Element s) { - return make_Quaternion( - q.x() * s, - q.y() * s, - q.z() * s, - q.w() * s - ); -} - -/// Quaternion scalar multiplication -template -CUTLASS_HOST_DEVICE -Quaternion operator*(Element s, Quaternion const &q) { - return make_Quaternion( - s * q.x(), - s * q.y(), - s * q.z(), - s * q.w() - ); -} - -/// Quaternion scalar division -template -CUTLASS_HOST_DEVICE -Quaternion operator/(Quaternion const &q, Element s) { - return make_Quaternion( - q.x() / s, - q.y() / s, - q.z() / s, - q.w() / s - ); -} - -/// Quaternion unary negation -template -CUTLASS_HOST_DEVICE -Quaternion operator-(Quaternion const &q) { - return make_Quaternion( - -q.x(), - -q.y(), - -q.z(), - -q.w() - ); -} - -/// Quaternion addition -template -CUTLASS_HOST_DEVICE -Quaternion operator+(Quaternion const &lhs, Quaternion const &rhs) { - return make_Quaternion( - lhs.x() + rhs.x(), - lhs.y() + rhs.y(), - lhs.z() + rhs.z(), - lhs.w() + rhs.w() - ); -} - -/// Quaternion subtraction -template -CUTLASS_HOST_DEVICE -Quaternion operator-(Quaternion const &lhs, Quaternion const &rhs) { - return make_Quaternion( - lhs.x() - rhs.x(), - lhs.y() - rhs.y(), - lhs.z() - rhs.z(), - lhs.w() - rhs.w() - ); -} - -/// Quaternion product -template -CUTLASS_HOST_DEVICE -Quaternion operator*(Quaternion const &lhs, Quaternion const &rhs) { - return make_Quaternion( - lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y(), - lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z(), - lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x(), - lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z() - ); -} - -/// Quaternion division -template -CUTLASS_HOST_DEVICE -Quaternion operator/(Quaternion const &lhs, Quaternion const &rhs) { - return lhs * reciprocal(rhs); -} - -/// Quaternion scalar division -template -CUTLASS_HOST_DEVICE -Quaternion operator/(Element s, Quaternion const &q) { - return s * reciprocal(q); -} - -/// Comparison -template -CUTLASS_HOST_DEVICE -bool operator<(Quaternion const &lhs, Quaternion const &rhs) { - return true; -} - -/// Rotates a 3-vector assuming this is a unit quaternion (a spinor). This avoids computing -/// a reciprocal. -template -CUTLASS_HOST_DEVICE -Matrix3x1 spinor_rotation( - Quaternion const &spinor, /// unit-length quaternion - Matrix3x1 const &rhs) { /// arbitrary 3-vector - - return (spinor * Quaternion(rhs, 0) * conj(spinor)).pure(); -} - -/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor). This avoids computing -/// a reciprocal. -template -CUTLASS_HOST_DEVICE -Matrix3x1 spinor_rotation_inv( - Quaternion const &spinor, /// unit-length quaternion - Matrix3x1 const &rhs) { /// arbitrary 3-vector - - return (conj(spinor) * Quaternion(rhs, 0) * spinor).pure(); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Quaternion-valued type. -template -struct RealType< Quaternion > { - using Type = T; - - /// Number of elements - static int const kExtent = Quaternion::kExtent; - -CUTLASS_HOST_DEVICE - static Quaternion from_real(double x) { - return Quaternion(static_cast(x)); - } -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Factories -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_HOST_DEVICE -cutlass::Quaternion from_real >(double r) { - return cutlass::Quaternion(half_t(r)); -} - -template <> -CUTLASS_HOST_DEVICE -cutlass::Quaternion from_real >(double r) { - return cutlass::Quaternion(float(r)); -} - -template <> -CUTLASS_HOST_DEVICE -cutlass::Quaternion from_real >(double r) { - return cutlass::Quaternion(r); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// functional.h numeric specializations -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct multiplies> { - CUTLASS_HOST_DEVICE - Quaternion operator()(Quaternion lhs, Quaternion const &rhs) const { - lhs = lhs * rhs; - return lhs; - } -}; - -/// Squares with optional conversion -template -struct magnitude_squared, Output> { - CUTLASS_HOST_DEVICE - Output operator()(Quaternion lhs) const { - multiplies mul_op; - - Output y_w = Output(lhs.w()); - Output y_x = Output(lhs.x()); - Output y_y = Output(lhs.y()); - Output y_z = Output(lhs.z()); - - return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + \ - mul_op(y_z, y_z); - } -}; - -template -struct multiply_add, Quaternion, Quaternion> { - CUTLASS_HOST_DEVICE - Quaternion operator()( - Quaternion const &a, - Quaternion const &b, - Quaternion const &c) const { - - T x = c.x(); - T y = c.y(); - T z = c.z(); - T w = c.w(); - - x += a.w() * b.x(); - x += b.w() * a.x(); - x += a.y() * b.z(); - x += -a.z() * b.y(), - - y += a.w() * b.y(); - y += b.w() * a.y(); - y += a.z() * b.x(); - y += -a.x() * b.z(); - - z += a.w() * b.z(); - z += b.w() * a.z(); - z += a.x() * b.y(); - z += -a.y() * b.x(); - - w += a.w() * b.w(); - w += -a.x() * b.x(); - w += -a.y() * b.y(); - w += -a.z() * b.z(); - - return cutlass::make_Quaternion(x, y, z, w); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/real.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/real.h deleted file mode 100644 index cfca386610d5b6412b98d942c45ca28c2129ec1f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/real.h +++ /dev/null @@ -1,63 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/** - \file - \brief This class provides helpers to support real<> and complex<> types in generic code. -*/ - -#pragma once - -#include // CUTLASS_DEVICE - -namespace cutlass { - -/// Used to determine the real-valued underlying type of a numeric type T. -template -struct RealType { - using Type = T; - - /// Number of elements - static int const kExtent = 1; - -CUTLASS_HOST_DEVICE - static T from_real(double x) { - return static_cast(x); - } -}; - -template -CUTLASS_HOST_DEVICE -static T from_real(double r) { - return T(r); -} - - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h deleted file mode 100644 index 92b57aae26e22cc7a5859568882a9661f022c5a7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h +++ /dev/null @@ -1,232 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over densely packed tensors in global memory -*/ - -#pragma once - -#include "cutlass/device_kernel.h" -#include "cutlass/reduction/kernel/reduce_split_k.h" -#include "cutlass/cuda_host_adapter.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ReductionKernel_ -> -class ReduceSplitK { -public: - using ReductionKernel = ReductionKernel_; - - using Shape = typename ReductionKernel::Shape; - using ReductionOp = typename ReductionKernel::ReductionOp; - using OutputOp = typename ReductionKernel::OutputOp; - - using ElementWorkspace = typename ReductionKernel::ElementWorkspace; - using ElementAccumulator = typename ReductionKernel::ElementAccumulator; - using ElementOutput = typename ReductionKernel::ElementOutput; - - using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef; - using OutputTensorRef = typename ReductionKernel::OutputTensorRef; - - using StrideIndex = typename ReductionKernel::StrideIndex; - - static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - - /// Argument structure - struct Arguments { - - // - // Data members - // - - MatrixCoord problem_size{0,0}; - int partitions{1}; - size_t partition_stride{0}; - WorkspaceTensorRef workspace{}; - OutputTensorRef destination{}; - OutputTensorRef source{}; - typename OutputOp::Params output{}; - typename ReductionOp::Params reduction{}; - - // - // Methods - // - - /// Default ctor - Arguments() = default; - - CUTLASS_HOST_DEVICE - Arguments( - MatrixCoord const & problem_size - ): - problem_size(problem_size) { } - - CUTLASS_HOST_DEVICE - Arguments( - MatrixCoord problem_size_, - int partitions_, - size_t partition_stride_, - WorkspaceTensorRef workspace_, - OutputTensorRef destination_, - OutputTensorRef source_, - typename OutputOp::Params output_ = typename OutputOp::Params(), - typename ReductionOp::Params reduction_ = typename ReductionOp::Params() - ): - problem_size(problem_size_), - partitions(partitions_), - partition_stride(partition_stride_), - workspace(workspace_), - destination(destination_), - source(source_), - output(output_), - reduction(reduction_) - { - - } - - }; - -private: - /// Kernel parameters object - typename ReductionKernel::Params params_; - -public: - /// Constructs Reduction SplitK - ReduceSplitK() { } - - /// Determines whether the ReduceSplitK can execute the given problem. - static Status can_implement(Arguments const &args) { - - return Status::kSuccess; - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - // needs no additional workspace - return 0; - } - - /// Initializes Reduction state from arguments. - Status initialize( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - - // initialize the params structure from the arguments - params_ = typename ReductionKernel::Params( - args.problem_size, - args.partitions, - args.partition_stride, - args.workspace, - args.destination, - args.source, - args.output, - args.reduction - ); - - return Status::kSuccess; - - } - - /// Initializes Reduction kernel state from arguments. - Status update(Arguments const &args, void *workspace = nullptr) { - - // update the params structure from the arguments - params_.workspace.reset(args.workspace.non_const_ref().data()); - params_.destination.reset(args.destination.non_const_ref().data()); - params_.source.reset(args.source.non_const_ref().data()); - params_.output = args.output; - params_.reduction = args.reduction; - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - - // - // Launch reduction kernel - // - dim3 block = ReductionKernel::block_shape(); - dim3 grid = ReductionKernel::grid_shape(params_.problem_size); - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - void* kernel_params[] = {¶ms_}; - cuda_adapter->launch( - grid, dim3(1,1,1), block, 0, stream, kernel_params, kernel_index); - } - } - else { - cutlass::arch::synclog_setup(); - Kernel<<< grid, block, 0, stream >>>(params_); - } - - cudaError_t result = cudaGetLastError(); - return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; - } - - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - return run(stream, cuda_adapter, kernel_index); - } - - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) { - status = run(stream,cuda_adapter, kernel_index); - } - - return status; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace reduction -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h deleted file mode 100644 index 26a0249e9c259dbf2930832d2819188ec74bda60..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h +++ /dev/null @@ -1,264 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over one or more ranks of an affine tensor -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/reduction/device/tensor_reduce_affine_strided.h" -#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tensor reduction operator on specific CUTLASS layouts over exactly one index -template < - typename ElementOutput_, - typename ElementSource_, - typename Layout_, - typename ReductionOp_, - int VectorLength_ = 1, - typename ElementCompute_ = ElementOutput_ -> -struct TensorReduction { - - using ElementOutput = ElementOutput_; - using ElementSource = ElementSource_; - using Layout = Layout_; - using ReductionOp = ReductionOp_; - static int const kVectorLength = VectorLength_; - using ElementCompute = ElementCompute_; - - using TensorCoord = typename Layout::TensorCoord; - - /// Reduction operator - using ReductionDeviceStridedOperator = TensorReductionAffineStrided< - 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute - >; - - using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous< - 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute - >; - - // - // Data members - // - - ReductionDeviceStridedOperator reduction_strided; - ReductionDeviceContiguousOperator reduction_contiguous; - int reduction_index; - - // - // Methods - // - - /// - TensorReduction( - TensorCoord extent, - int reduction_index_ - ): - reduction_index(reduction_index_) { - - Coord<4> extent_affine; - - switch (reduction_index) { - case 0: - extent_affine[0] = extent[1]; - extent_affine[1] = extent[2]; - extent_affine[2] = extent[0]; - extent_affine[3] = extent[3]; - break; - case 1: - extent_affine[0] = extent[0]; - extent_affine[1] = extent[2]; - extent_affine[2] = extent[1]; - extent_affine[3] = extent[3]; - break; - case 2: - extent_affine[0] = extent[0]; - extent_affine[1] = extent[1]; - extent_affine[2] = extent[2]; - extent_affine[3] = extent[3]; - break; - case 3: - extent_affine[0] = extent[0]; - extent_affine[1] = extent[1]; - extent_affine[2] = extent[2]; - extent_affine[3] = extent[3]; - break; - default: break; - } - - if (reduction_index == 3) { - reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine); - } - else { - reduction_strided = ReductionDeviceStridedOperator(extent_affine); - } - } - - /// Simple check to verify the object is initialized correctly - bool good() const { - if (reduction_index == 3) { - return reduction_contiguous.good(); - } - return reduction_strided.good(); - } - - /// Size of one workspace - int64_t workspace_stride() const { - if (reduction_index == 3) { - return reduction_contiguous.workspace_stride(); - } - else { - return reduction_strided.workspace_stride(); - } - } - - /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs - int64_t workspace_size() const { - if (reduction_index == 3) { - return reduction_contiguous.workspace_size(); - } - else { - return reduction_strided.workspace_size(); - } - } - - /// Helper to use overloaded function call operator - Status reduce( - TensorRef dst_ref, - TensorRef src_ref, - void *device_workspace_ptr = nullptr, - ElementCompute reduction_identity = ElementCompute(), - ReductionOp reduction_op = ReductionOp(), - cudaStream_t stream = nullptr) { - - int64_t src_stride[3]; - int64_t dst_stride[3]; - - switch (reduction_index) { - case 0: - src_stride[0] = src_ref.stride()[1]; - src_stride[1] = src_ref.stride()[0]; - src_stride[2] = src_ref.stride()[2]; - dst_stride[0] = dst_ref.stride()[1]; - dst_stride[1] = dst_ref.stride()[0]; - break; - case 1: - src_stride[0] = src_ref.stride()[2]; - src_stride[1] = src_ref.stride()[0]; - src_stride[2] = src_ref.stride()[1]; - dst_stride[0] = dst_ref.stride()[2]; - dst_stride[1] = dst_ref.stride()[0]; - break; - case 2: - src_stride[0] = src_ref.stride()[2]; - src_stride[1] = src_ref.stride()[1]; - src_stride[2] = src_ref.stride()[0]; - dst_stride[0] = dst_ref.stride()[2]; - dst_stride[1] = dst_ref.stride()[1]; - break; - case 3: - src_stride[0] = src_ref.stride()[2]; - src_stride[1] = src_ref.stride()[1]; - src_stride[2] = src_ref.stride()[0]; - - dst_stride[0] = dst_ref.stride()[2]; - dst_stride[1] = dst_ref.stride()[1]; - dst_stride[2] = dst_ref.stride()[0]; - - default: break; - } - - if (reduction_index == 3) { - return reduction_contiguous( - dst_ref.data(), - dst_stride, - src_ref.data(), - src_stride, - device_workspace_ptr, - reduction_identity, - reduction_op, - stream); - } - else { - return reduction_strided( - dst_ref.data(), - dst_stride, - src_ref.data(), - src_stride, - device_workspace_ptr, - reduction_identity, - reduction_op, - stream); - } - } - - Status operator()( - TensorRef dst_ref, - TensorRef src_ref, - void *device_workspace_ptr = nullptr, - ElementCompute reduction_identity = ElementCompute(), - ReductionOp reduction_op = ReductionOp(), - cudaStream_t stream = nullptr) { - - return reduce( - dst_ref, - src_ref, - device_workspace_ptr, - reduction_identity, - reduction_op, - stream); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reduction -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h deleted file mode 100644 index c00c368165902bdda08f6316a07be19668dc0fb9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h +++ /dev/null @@ -1,374 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over one or more ranks of an affine tensor -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tensor reduction operator on layouts which are affine -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (e.g. ND => 2) - typename ElementOutput_, - typename ElementSource_, - typename ReductionOp_, - int VectorLength = 1, - typename ElementCompute_ = ElementOutput_, - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -struct TensorReductionAffineContiguous { - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - - using ElementOutput = ElementOutput_; - using ElementSource = ElementSource_; - using ReductionOp = ReductionOp_; - using ElementCompute = ElementCompute_; - - // - // Data members - // - - /// Internal status field - Status status; - - /// Extent of tensor in source layout - Coord extent; - - /// Number of points in the outer index space - int64_t outer_count; - - /// Number of elements in the inner index space - int64_t inner_count; - - /// Number of workspaces needed - int workspace_count; - - /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) - dim3 grid_shape; - - /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) - dim3 threadblock_shape; - - /// CUDA grid shape for the final reduction step if needed - dim3 grid_final; - - /// CUDA threadblock shape for the final reduction step if needed - dim3 threadblock_final; - -private: - // - // Methods - // - - /// Helper to reshape 'count' such that it is less than 2 x 'ext' - static int reshape_pow2(int ext, int count) { - if (ext > count) { - return 1; - } - int x = 1; - for (; count >= ext * 2; ) { - count >>= 1; - x <<= 1; - } - return x; - } - -public: - - /// Default ctor - TensorReductionAffineContiguous(): - status(Status::kErrorInvalidProblem), - extent(), - outer_count(0), - inner_count(0), - workspace_count(0), - grid_shape(0, 0, 0), - threadblock_shape(0, 0, 0) { } - - /// Constructor - TensorReductionAffineContiguous( - Coord extent_, - int target_threadblock_count = 128 - ): - status(Status::kSuccess), - extent(extent_), - outer_count(0), - inner_count(0), - workspace_count(0) { - - // - // Plan the parallel mapping strategy. - // - - outer_count = 1; - inner_count = 1; - - // Compute number of elements in strided ranks - for (int p = 0; p < kReducedRank; ++p) { - outer_count *= extent[p]; - } - - for (int p = 0; p < kInnerRank; ++p) { - inner_count *= extent[kReducedRank + p]; - } - - int cta_count_x = 1; - int cta_count_y = 1; - int cta_count_z = 1; - - int cta_threads_x = kThreads; - int cta_threads_y = 1; - int cta_threads_z = 1; - - // Determine CTA shape - int64_t inner_vector_count = inner_count / kVectorLength; - - // Priority 1. Assign threadblocks to outer indices if possible - if (outer_count > target_threadblock_count) { - cta_count_x = 1; - cta_count_y = target_threadblock_count; - cta_count_z = 1; - } - else { - - cta_count_y = int(outer_count); - int remaining_ctas = target_threadblock_count / cta_count_y; - - // Priority 2. Assign inner dimensions to one CTA - if (inner_vector_count > cta_threads_x) { - int64_t cta_z_bound = inner_vector_count / cta_threads_x; - if (cta_z_bound > remaining_ctas) { - cta_count_z = remaining_ctas; - } - else { - cta_count_z = int(cta_z_bound); - } - } - else { - cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x); - cta_count_z = 1; - } - } - - grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); - threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z); - - workspace_count = (cta_count_z > 1 ? cta_count_z : 0); - - // Determine shape of final reduction kernel if needed - if (workspace_count) { - - int final_threads = kThreads; - int final_ctas = 1; - - if (outer_count > kThreads) { - final_ctas = int(outer_count + kThreads - 1) / kThreads; - } - else { - final_threads = int(outer_count); - } - - grid_final = dim3(final_ctas, 1, 1); - threadblock_final = dim3(final_threads, 1, 1); - } - else { - grid_final = dim3(0, 0, 0); - threadblock_final = dim3(0, 0, 0); - } - } - - /// Simple check to verify the object is initialized correctly - bool good() const { - return status == Status::kSuccess; - } - - /// Size (in bytes) of workspace elements which are densely packed together - int64_t workspace_stride() const { - - // Error condition - if (!good()) { - return 0; - } - - return outer_count * sizeof_bits::value / 8; - } - - /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs - int64_t workspace_size() const { - - // Error condition - if (!good()) { - return 0; - } - - // No reduction across CTAs - if (grid_shape.z == 1) { - return 0; - } - - return workspace_stride() * grid_shape.z; - } - - /// Performs a reduction - Status reduce( - ElementOutput *dst_ptr, ///< Pointer to destination tensor - int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) - ElementSource const *src_ptr, ///< Pointer to source tensor - int64_t src_stride[], ///< Stride vector (of length kRank - 1) - void *device_workspace_ptr = nullptr, ///< Device workspace - ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element - ReductionOp reduction_op = ReductionOp(), ///< Reduction operator - cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched - - // Initial status check - if (!good()) { - return status; - } - - // Guard against null workspace - if (workspace_count > 1 && device_workspace_ptr == nullptr) { - return Status::kErrorWorkspaceNull; - } - - // Define reduction kernel - using ReductionKernel = kernel::TensorReductionAffineContiguous< - kRank, - kReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - kVectorLength, - ElementCompute, - kThreads>; - - using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal< - kRank, - kReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - kVectorLength, - ElementCompute, - kThreads>; - - using Params = typename ReductionKernel::Params; - - // Construct the parameters - Params params( - extent, - dst_ptr, - dst_stride, - src_ptr, - src_stride, - static_cast(device_workspace_ptr), - workspace_stride(), - workspace_count, - reduction_op, - reduction_identity); - - // Shared memory size - int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); - - // Launch the kernel - cutlass::arch::synclog_setup(); - Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); - - // Check error condition - if (cudaPeekAtLastError() == cudaSuccess) { - status = Status::kSuccess; - } - else { - status = Status::kErrorInternal; - } - - // Final reduction kernel - if (workspace_count) { - Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); - } - - // Check error condition - if (cudaPeekAtLastError() == cudaSuccess) { - status = Status::kSuccess; - } - else { - status = Status::kErrorInternal; - } - - return status; - } - - /// Helper to use overloaded function call operator - Status operator()( - ElementOutput *dst_ptr, ///< Pointer to destination tensor - int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) - ElementSource const *src_ptr, ///< Pointer to source tensor - int64_t src_stride[], ///< Stride vector (of length kRank - 1) - void *device_workspace_ptr = nullptr, ///< Pointer to device workspace - ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element - ReductionOp reduction_op = ReductionOp(), ///< Reduction operator - cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched - - return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reduction -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h deleted file mode 100644 index c85d6dcbf13ba17a82b252124313c58f901e55f5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h +++ /dev/null @@ -1,362 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over one or more ranks of an affine tensor -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tensor reduction operator on layouts which are affine -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) - typename ElementOutput_, - typename ElementSource_, - typename ReductionOp_, - int VectorLength = 1, - typename ElementCompute_ = ElementOutput_, - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -struct TensorReductionAffineStrided { - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - - using ElementOutput = ElementOutput_; - using ElementSource = ElementSource_; - using ReductionOp = ReductionOp_; - using ElementCompute = ElementCompute_; - - // - // Data members - // - - /// Internal status field - Status status; - - /// Extent of tensor in source layout - Coord extent; - - /// Number of points in the outer index space - int64_t outer_count; - - /// Number of elements in the inner index space - int64_t inner_count; - - /// Number of workspaces needed - int workspace_count; - - /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) - dim3 grid_shape; - - /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) - dim3 threadblock_shape; - - /// CUDA grid shape for the final reduction step if needed - dim3 grid_final; - - /// CUDA threadblock shape for the final reduction step if needed - dim3 threadblock_final; - -private: - // - // Methods - // - - /// Helper to reshape 'count' such that it is less than 2 x 'ext' - static int reshape_pow2(int ext, int count) { - if (ext > count) { - return 1; - } - int x = 1; - for (; count >= ext * 2; ) { - count >>= 1; - x <<= 1; - } - return x; - } - -public: - - /// Default ctor - TensorReductionAffineStrided(): - status(Status::kErrorInvalidProblem), - extent(), - outer_count(0), - inner_count(0), - workspace_count(0), - grid_shape(0, 0, 0), - threadblock_shape(0, 0, 0) { } - - /// Constructor - TensorReductionAffineStrided( - Coord extent_, - int target_threadblock_count = 128 - ): - status(Status::kSuccess), - extent(extent_), - outer_count(0), - inner_count(0), - workspace_count(0) { - - // - // Plan the parallel mapping strategy. - // - - outer_count = 1; - inner_count = 1; - - // Compute number of elements in strided ranks - for (int p = 0; p < kReducedRank - 1; ++p) { - outer_count *= extent[p]; - } - - for (int p = 0; p < kInnerRank; ++p) { - inner_count *= extent[kReducedRank + p - 1]; - } - - // Compute plan for the reduction - int extent_c = extent[kRank - 1]; - int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength; - - // Determine CTA shape - int cta_width = kThreads * kVectorLength; - int cta_ways = reshape_pow2(extent_c, cta_width); - int cta_threads_x = kThreads / cta_ways; - - threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64)); - - // This leads to an error. - if (threadblock_shape.z > 1) { - if (threadblock_shape.y != 1) { - status = Status::kErrorInternal; - return; - } - } - - // Determine grid shape - int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x; - int cta_count_y = std::max(1, target_threadblock_count / cta_count_x); - - // Limit the number of CTAs assigned to outer dimension - if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) { - cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y; - } - - // Limit the number of CTAs assigned to inner dimension - int cta_count_z = std::max(1, target_threadblock_count / cta_count_y); - if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) { - cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z; - } - - grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); - workspace_count = (cta_count_z > 1 ? cta_count_z : 0); - - // Determine shape of final reduction kernel if needed - grid_final = dim3(cta_count_x, int(outer_count)); - threadblock_final = dim3(cta_threads_x, 1, 1); - } - - /// Simple check to verify the object is initialized correctly - bool good() const { - return status == Status::kSuccess; - } - - /// Size of one CTA's workspace - int64_t workspace_stride() const { - - // Error condition - if (!good()) { - return 0; - } - - int vector_size_bytes = kVectorLength * sizeof_bits::value / 8; - - return extent[kRank - 1] * vector_size_bytes; - } - - /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs - int64_t workspace_size() const { - - // Error condition - if (!good()) { - return 0; - } - - // No reduction across CTAs - if (grid_shape.z == 1) { - return 0; - } - - return workspace_stride() * outer_count * grid_shape.z; - } - - /// Performs a reduction - Status reduce( - ElementOutput *dst_ptr, ///< Pointer to destination tensor - int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) - ElementSource const *src_ptr, ///< Pointer to source tensor - int64_t src_stride[], ///< Stride vector (of length kRank - 1) - void *device_workspace_ptr = nullptr, ///< Device workspace - ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity - ReductionOp reduction_op = ReductionOp(), ///< Reduction operator - cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched - - // Initial status check - if (!good()) { - return status; - } - - // Guard against null workspace - if (workspace_count > 1 && device_workspace_ptr == nullptr) { - return Status::kErrorWorkspaceNull; - } - - // Define reduction kernel - using ReductionKernel = kernel::TensorReductionAffineStrided< - kRank, - kReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - kVectorLength, - ElementCompute, - kThreads>; - - using FinalReductionKernel = kernel::TensorReductionAffineStridedFinal< - kRank, - kReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - kVectorLength, - ElementCompute, - kThreads>; - - using Params = typename ReductionKernel::Params; - - // Construct the parameters - Params params( - extent, - dst_ptr, - dst_stride, - src_ptr, - src_stride, - static_cast(device_workspace_ptr), - workspace_stride(), - workspace_count, - reduction_op, - reduction_identity); - - // Shared memory size - int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); - - // Launch the kernel - cutlass::arch::synclog_setup(); - Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); - - // Check error condition - if (cudaPeekAtLastError() == cudaSuccess) { - status = Status::kSuccess; - } - else { - status = Status::kErrorInternal; - } - - // Final reduction kernel - if (workspace_count) { - - Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); - - // Check error condition - if (cudaPeekAtLastError() == cudaSuccess) { - status = Status::kSuccess; - } - else { - status = Status::kErrorInternal; - } - } - - return status; - } - - /// Helper to use overloaded function call operator - Status operator()( - ElementOutput *dst_ptr, ///< Pointer to destination tensor - int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) - ElementSource const *src_ptr, ///< Pointer to source tensor - int64_t src_stride[], ///< Stride vector (of length kRank - 1) - void *device_workspace_ptr = nullptr, ///< Pointer to device workspace - ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity - ReductionOp reduction_op = ReductionOp(), ///< Reduction operator - cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched - - return reduce( - dst_ptr, - dst_stride, - src_ptr, - src_stride, - device_workspace_ptr, - reduction_identity, - reduction_op, - stream); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reduction -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h deleted file mode 100644 index 3d39dc751c4bdef328398c5a94e5462136728f6a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h +++ /dev/null @@ -1,267 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a final reduction for softmax -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/functional.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace kernel { - -template < - typename ElementNorm_, - typename ElementSum_, - typename ElementSoftmaxCompute_, - typename ThreadblockShape_, - bool GroupedProblem = false -> -class ApplySoftmaxFinalReduction { -public: - - using ElementNorm = ElementNorm_; - using ElementSum = ElementSum_; - using ElementSoftmaxCompute = ElementSoftmaxCompute_; - using ThreadblockShape = ThreadblockShape_; - static const bool isGroupedProblem = GroupedProblem; - - // - // Arguments - // - - struct Arguments { - - cutlass::gemm::GemmCoord* problem_sizes{nullptr}; - cutlass::gemm::GemmCoord problem_size{}; - ElementNorm* block_Norm{nullptr}; - ElementSum* block_Sum{nullptr}; - int64_t* offset_Norm_Device{nullptr}; - int64_t* offset_Sum_Device{nullptr}; - int64_t batch_stride_Max{0}; - int64_t batch_stride_Sum{0}; - - // - // Methods - // - Arguments() { } - - // Non-grouped constructor without batching - Arguments( - cutlass::gemm::GemmCoord problem_size, - ElementNorm* block_Norm, - ElementSum* block_Sum - ): - problem_size(problem_size), - block_Norm(block_Norm), - block_Sum(block_Sum), - problem_sizes(nullptr), - offset_Norm_Device(nullptr), - offset_Sum_Device(nullptr), - batch_stride_Max(0), - batch_stride_Sum(0) - { - - } - - // Non-grouped constructor with batching - Arguments( - cutlass::gemm::GemmCoord problem_size, - ElementNorm* block_Norm, - ElementSum* block_Sum, - int64_t batch_stride_Max, - int64_t batch_stride_Sum - ): - problem_size(problem_size), - block_Norm(block_Norm), - block_Sum(block_Sum), - batch_stride_Max(batch_stride_Max), - batch_stride_Sum(batch_stride_Sum), - problem_sizes(nullptr), - offset_Norm_Device(nullptr), - offset_Sum_Device(nullptr) - { - - } - - - // Grouped constructor - Arguments( - cutlass::gemm::GemmCoord *problem_sizes, - ElementNorm* block_Norm, - ElementSum* block_Sum, - int64_t* offset_Norm_Device, - int64_t* offset_Sum_Device - ): - problem_sizes(problem_sizes), - problem_size(cutlass::gemm::GemmCoord(0, 0, 0)), - block_Norm(block_Norm), - block_Sum(block_Sum), - offset_Norm_Device(offset_Norm_Device), - offset_Sum_Device(offset_Sum_Device) - { - - } - }; - - struct SharedStorage { - - - }; - - // - // Params struct - // - - struct Params { - Arguments args; - - // - // Methods - // - Params() { } - - Params(Arguments const &args_): args(args_) { } - }; - -private: - -public: - - CUTLASS_DEVICE - ApplySoftmaxFinalReduction() { } - - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - apply(params, shared_storage); - } - -private: - - /// Full reduction - CUTLASS_DEVICE - void apply(Params const ¶ms, SharedStorage &shared_storage) { - - int tid = threadIdx.x; - int bid = blockIdx.x; - int bdim = blockDim.x; - - int block_batch = blockIdx.z; - - // defining three vars for a general reduction module - cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; - int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; - int access_offset = isGroupedProblem ? 0 : bid * bdim; - - if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; - - ElementNorm *curr_ptr_Max = isGroupedProblem ? \ - params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ - params.args.block_Norm + block_batch * params.args.batch_stride_Max; - ElementSum *curr_ptr_Sum = isGroupedProblem ? \ - params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ - params.args.block_Sum + block_batch * params.args.batch_stride_Sum; - - int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; - - using ConvertSumOutput = cutlass::NumericConverter; - using ConvertNormOutput = cutlass::NumericConverter; - - using ConvertSum = cutlass::NumericConverter; - using ConvertNorm = cutlass::NumericConverter; - - ConvertSum convert_sum; - ConvertNorm convert_norm; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - uint32_t float_max_bits = 0xff7fffff; - float min_float = reinterpret_cast(float_max_bits); - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { - ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; - ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; - ElementNorm *access_n_bak = access_n; - ElementSum *access_s_bak = access_s; - ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); - ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); - ElementNorm fetch_n; - ElementSum fetch_s; - - CUTLASS_PRAGMA_UNROLL - for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { - cutlass::arch::global_load(fetch_n, access_n, true); - max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); - access_n += problem_size.m(); - } - - access_n = access_n_bak; - - CUTLASS_PRAGMA_UNROLL - for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { - cutlass::arch::global_load(fetch_n, access_n, true); - cutlass::arch::global_load(fetch_s, access_s, true); - sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); - access_n += problem_size.m(); - access_s += problem_size.m(); - } - - ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; - - access_n = access_n_bak; - access_s = access_s_bak; - - access_n[0] = convert_norm_output(max_val); - access_s[0] = convert_sum_output(inv_sum); - } - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace reduction -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h deleted file mode 100644 index f6d26666957a58321c579b191ec06c84503e8ca2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h +++ /dev/null @@ -1,248 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over densely packed tensors in global memory -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/functional.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/layout/matrix.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, ///< shape of CTA (concept: MatrixShape) - typename OutputOp_ , ///< output operator (concept: epilogue::thread operator) - typename ReductionOp_, ///< reduction operator (concept: ReductionOperator) - int PartitionsPerStage = 4 ///< number of partitions to issue -> -class ReduceSplitK { -public: - - using Shape = Shape_; - using ReductionOp = ReductionOp_; - using OutputOp = OutputOp_; - static int const kElementsPerAccess = OutputOp::kCount; - static int const kPartitionsPerStage = PartitionsPerStage; - - using ElementWorkspace = typename ReductionOp::Element; - using ElementAccumulator = typename ReductionOp::ElementAccumulator; - using ElementOutput = typename OutputOp::ElementOutput; - - using WorkspaceTensorRef = TensorRef; - using OutputTensorRef = TensorRef; - using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index; - - using FragmentWorkspace = AlignedArray; - using FragmentAccumulator = Array; - using FragmentOutput = AlignedArray; - - // - // Types - // - - /// Params structure - struct Params { - - MatrixCoord problem_size; - int partitions; - size_t partition_stride; - WorkspaceTensorRef workspace; - OutputTensorRef destination; - OutputTensorRef source; - typename OutputOp::Params output; - typename ReductionOp::Params reduction; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params( - MatrixCoord problem_size_, - int partitions_, - size_t partition_stride_, - WorkspaceTensorRef workspace_, - OutputTensorRef destination_, - OutputTensorRef source_, - typename OutputOp::Params output_ = typename OutputOp::Params(), - typename ReductionOp::Params reduction_ = typename ReductionOp::Params() - ): - problem_size(problem_size_), - partitions(partitions_), - partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess), - workspace(workspace_), - destination(destination_), - source(source_), - output(output_), - reduction(reduction_) { - - } - }; - - struct SharedStorage { }; - - -public: - - /// Computes the grid size given a chosen threadblock shape - CUTLASS_HOST_DEVICE - static dim3 grid_shape( - cutlass::MatrixCoord problem_size) { - - return dim3( - (problem_size.row() + Shape::kRow - 1) / Shape::kRow, - (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn); - } - - /// Determines the threadblock shape - CUTLASS_HOST_DEVICE - static dim3 block_shape() { - return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow); - } - - /// Perform a reduction - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &storage) { - - // Determine CTA position - MatrixCoord thread_offset( - MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y), - MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess) - ); - - // One guard conditional - if (!(thread_offset.row() < params.problem_size.row() && - thread_offset.column() < params.problem_size.column())) { - - return; - } - - - ReductionOp reduction_op(params.reduction); - - FragmentAccumulator accumulator; - - accumulator.clear(); - - // - // Load the first slice - // - - char const *workspace_ptr = - reinterpret_cast( - params.workspace.data() + params.workspace.offset(thread_offset)); - - FragmentWorkspace workspace_frag[kPartitionsPerStage]; - - // - // Construct the output operator - // - - OutputOp output_op(params.output); - - // - // Load and accumulate with a simple batched loading sequence. - // - - CUTLASS_PRAGMA_NO_UNROLL - for (int k = 0; k < params.partitions; k += kPartitionsPerStage) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPartitionsPerStage; ++i) { - if (k + i < params.partitions) { - workspace_frag[i] = *reinterpret_cast(workspace_ptr); - workspace_ptr += params.partition_stride; - } - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPartitionsPerStage; ++i) { - if (k + i < params.partitions) { - accumulator = reduction_op(accumulator, workspace_frag[i]); - } - } - } - - // - // Conditionally load the source - // - - FragmentOutput source_frag; - - source_frag.clear(); - - FragmentOutput const *source_ptr = reinterpret_cast( - params.source.data() + params.source.offset(thread_offset)); - - if (output_op.is_source_needed()) { - reinterpret_cast(source_frag) = *source_ptr; - } - - // - // Compute the output - // - - typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag); - - // - // Store - // - - FragmentOutput *dest_ptr = reinterpret_cast( - params.destination.data() + params.destination.offset(thread_offset)); - - *dest_ptr = reinterpret_cast(output_frag); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace reduction -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h deleted file mode 100644 index 914bbddda9227d1f1772d8e8171b06280b7a5f61..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h +++ /dev/null @@ -1,606 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over one or more ranks of an affine tensor -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/reduction/thread/reduction_operators.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parameters structure -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks) - typename ElementOutput, ///< Data type of output tensor - typename ElementSource, ///< Data type of source tensor - typename ReductionOp, ///< Reduction operator - int VectorLength = 1, ///< Vector length for memory - typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -struct TensorReductionAffineContiguousParams { - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - - Coord extent; /// Extent of source tensor - FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank - int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J - int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K - int64_t workspace_stride; /// stride (units of bytes) between workspace - int workspace_count; /// number of workspaces - - uint64_t inner_count; /// Number of elements in reduced index space - uint64_t outer_count; /// Number of elements in outer index space - - ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank - ElementSource const * source; /// Pointer to source pointer of rank kRank - ReductionOp reduction_op; /// Reduction operator - ElementCompute reduction_identity; /// Identity element used by reduction operator - ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorReductionAffineContiguousParams() { - - } - - /// Ctor - TensorReductionAffineContiguousParams( - Coord extent_, ///< Extent of source tensor - ElementOutput * dst_ptr_, ///< Output tensor data - int64_t dst_stride_[], ///< Stride (units of elements) - ElementSource const * src_ptr_, ///< Source tensor data - int64_t src_stride_[], ///< Stride (units of elements) - ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions - int64_t workspace_stride_, ///< Stride between workspaces - int workspace_count_, ///< Number of workspaces - ReductionOp reduction_op_, ///< Reduction operator - ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator - ): - extent(extent_), - inner_count(1), - outer_count(1), - destination(dst_ptr_), - source(src_ptr_), - device_workspace(device_workspace_), - workspace_stride(workspace_stride_), - workspace_count(workspace_count_), - reduction_op(reduction_op_), - reduction_identity(reduction_identity_) { - - // Initialize divisors for fast div-mod - for (int p = 1; p < kRank; ++p) { - divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); - } - - int input_size_bits = sizeof_bits::value; - int output_size_bits = sizeof_bits::value; - - // Compute strides in units of bytes - for (int p = 0; p < kReducedRank; ++p) { - dst_stride[p] = dst_stride_[p] * output_size_bits / 8; - } - - for (int p = 0; p < kRank - 1; ++p) { - src_stride[p] = src_stride_[p] * input_size_bits / 8; - } - - // Compute number of elements in strided ranks - for (int p = 0; p < kReducedRank; ++p) { - outer_count *= uint64_t(extent[p]); - } - - for (int p = 0; p < kInnerRank; ++p) { - inner_count *= uint64_t(extent[kRank - 1 - p]); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous -/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) - typename ElementOutput, ///< Data type of output tensor - typename ElementSource, ///< Data type of source tensor - typename ReductionOp, ///< Reduction operator - int VectorLength = 1, ///< Vector length for memory - typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -class TensorReductionAffineContiguous { -public: - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - using ComputeFragment = Array; - using SourceFragment = AlignedArray; - using OutputFragment = AlignedArray; - - /// Shared memory allocation used for reduction within the CTA - struct SharedStorage { - Array workspace; - }; - - /// Parameters structure - using Params = TensorReductionAffineContiguousParams< - Rank, - ReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - VectorLength, - ElementCompute, - Threads, - BatchSize - >; - -private: - - /// Computes the coordinate and offset of a given linear index - CUTLASS_DEVICE - void compute_inner_coord_and_offset_( - Params const ¶ms, - Coord & coord, - int64_t &src_offset, - uint64_t linear_idx) const { - - // Decompose into a coordinate of rank - coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kRank - kInnerRank]); - - // Compute an offset using the souce stride - src_offset = 0; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kInnerRank - 1; ++i) { - src_offset += coord[i] * params.src_stride[kReducedRank + i]; - } - src_offset += coord[kInnerRank - 1] * sizeof_bits::value / 8; - } - - /// Computes the coordinate and offset of a given linear index - CUTLASS_DEVICE - void compute_outer_coord_and_offset_( - Params const ¶ms, - Coord & coord, - int64_t &dst_offset, - int64_t &src_offset, - uint64_t linear_idx) const { - - // Decompose into coordinate of rank - coord = CoordinateDecomposition(linear_idx, params.divmod); - - // Compute offsets using destination and source strides - dst_offset = 0; - src_offset = 0; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kReducedRank; ++i) { - dst_offset += params.dst_stride[i] * coord[i]; - src_offset += params.src_stride[i] * coord[i]; - } - } - - /// Reduces over the reduction indices yielding a single element - CUTLASS_DEVICE - ElementCompute reduce_indices_( - Params const ¶ms, - ElementCompute *threadblock_workspace, - char const *src_byte_ptr, - int coord_c) { - - NumericArrayConverter convert_source; - ReductionOp reduction_op(params.reduction_op); - - // - // Early exit or initialize to identity element - // - if (!params.inner_count) { - return params.reduction_identity; - } - - ComputeFragment accumulator; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(accumulator.size()); ++i) { - accumulator[i] = params.reduction_identity; - } - - // Compute the coordinate of the first access - int64_t src_byte_offset = 0; - Coord coord; - - uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength; - compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); - - // Load the first vector - SourceFragment source_fragment[kBatchSize]; - - bool not_done = true; - - // Iterate over vectors in a linearized reduction index space - while (not_done) { - - bool guards[kBatchSize]; - - // Issue a batch of loads - CUTLASS_PRAGMA_UNROLL - for (int b = 0; b < kBatchSize; ++b) { - - if (linear_idx < params.inner_count) { - source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); - guards[b] = true; - } - else { - guards[b] = false; - not_done = false; - } - - linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength; - compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); - } - - // Perform a batch of reduction operations - CUTLASS_PRAGMA_UNROLL - for (int b = 0; b < kBatchSize; ++b) { - if (guards[b]) { - auto cvt = convert_source(source_fragment[b]); - - accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( - reduction_op, - accumulator, - cvt); - } - } - }; - - // - // Reduction of vectors to scalar - // - - ElementCompute reduced_accumulator = accumulator[0]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kVectorLength; ++i) { - reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]); - } - - // - // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0} - // - // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column - // - - int thread_count = blockDim.x * blockDim.z; - int thread_j = threadIdx.x + blockDim.x * threadIdx.z; - int thread_i = threadIdx.y; - - ElementCompute *frag_ptr = reinterpret_cast(threadblock_workspace) + thread_i * thread_count; - - frag_ptr[thread_j] = reduced_accumulator; - - // - // Reduce - // - CUTLASS_PRAGMA_NO_UNROLL - while (thread_count > 1) { - thread_count /= 2; - - __syncthreads(); - - if (thread_j < thread_count) { - ElementCompute other = frag_ptr[thread_j + thread_count]; - - reduced_accumulator = reduction_op(reduced_accumulator, other); - - frag_ptr[thread_j] = reduced_accumulator; - } - - __syncthreads(); - } - - - return reduced_accumulator; - } - -public: - - /// Perform a reduction - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; - - char const * src_byte_ptr = reinterpret_cast(params.source); - char * dst_byte_ptr = nullptr; - - // If performing a reduction across CTAs, redirect output to device workspace - if (gridDim.z == 1) { - dst_byte_ptr = reinterpret_cast(params.destination); - } - else { - dst_byte_ptr = reinterpret_cast(params.device_workspace); - } - - uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; - - // Use modulo division to compute location - Coord outer_coord; - int64_t dst_byte_offset; - int64_t src_byte_offset; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - src_byte_offset, - idx_linear); - - if (gridDim.z == 1) { - - /// Complete the reduction with no workspace - while (idx_linear < params.outer_count) { - - ElementCompute result = reduce_indices_( - params, - shared_storage.workspace.data(), - src_byte_ptr + src_byte_offset, - coord_c); - - // Store the result after possible final reduction within the CTA - if (threadIdx.z == 0 && threadIdx.x == 0) { - - // Convert to output type and store - NumericConverter convert_output; - ElementOutput cvt = convert_output(result); - - *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = cvt; - } - - __syncthreads(); - - // Update indices and pointers - idx_linear += gridDim.y * blockDim.y; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - src_byte_offset, - idx_linear); - - } // while - } - else { - - /// Complete the reduction with workspace - while (idx_linear < params.outer_count) { - - ElementCompute result = reduce_indices_( - params, - shared_storage.workspace.data(), - src_byte_ptr + src_byte_offset, - coord_c); - - int64_t byte_offset = - blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits::value / 8; - - // Store the result for final reduction - if (threadIdx.z == 0 && threadIdx.x == 0) { - *reinterpret_cast(dst_byte_ptr + byte_offset) = result; - } - - __syncthreads(); - - // Update indices and pointers - idx_linear += gridDim.y * blockDim.y; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - src_byte_offset, - idx_linear); - } // while - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel to perform final reduction -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) - typename ElementOutput, ///< Data type of output tensor - typename ElementSource, ///< Data type of source tensor - typename ReductionOp, ///< Reduction operator - int VectorLength = 1, ///< Vector length for memory - typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -class TensorReductionAffineContiguousFinal { -public: - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - - /// Shared memory - struct SharedStorage { }; - - /// Parameters structure - using Params = TensorReductionAffineContiguousParams< - Rank, - ReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - VectorLength, - ElementCompute, - Threads, - BatchSize - >; - -private: - - /// Computes the coordinate and offset of a given linear index - CUTLASS_DEVICE - void compute_outer_coord_and_offset_( - Params const ¶ms, - Coord & coord, - int64_t &dst_offset, - uint64_t linear_idx) const { - - // Decompose into coordinate of rank - coord = CoordinateDecomposition(linear_idx, params.divmod); - - // Compute offsets using destination and source strides - dst_offset = 0; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kReducedRank; ++i) { - dst_offset += params.dst_stride[i] * coord[i]; - } - } - - /// Reduces over the reduction indices - CUTLASS_DEVICE - ElementCompute reduce_indices_( - Params const ¶ms, - ElementCompute const *device_workspace) { - - ReductionOp reduction_op(params.reduction_op); - char const *src_byte_ptr = reinterpret_cast(device_workspace); - - // Accumulated output - ElementCompute accumulator = params.reduction_identity; - - for (int iter = 0; iter < params.workspace_count; ++iter) { - ElementCompute workspace_item = *reinterpret_cast(src_byte_ptr); - - accumulator = reduction_op(accumulator, workspace_item); - - src_byte_ptr += params.workspace_stride; - } - - return accumulator; - } - -public: - - // - // Methods - // - - /// Perform a reduction - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x; - - char * dst_byte_ptr = reinterpret_cast(params.destination); - - // Use modulo division to compute location - Coord outer_coord; - int64_t dst_byte_offset; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - idx_linear); - - /// Complete the reduction - while (idx_linear < params.outer_count) { - - ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear); - - // Convert to output type and store - NumericConverter convert_output; - - *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = convert_output(result); - - // Update indices and pointers - idx_linear += gridDim.x * blockDim.x; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - idx_linear); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace reduction -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h deleted file mode 100644 index 0538184f3886b53207cc28a46a9fb8b04d3e8c5e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h +++ /dev/null @@ -1,641 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over one or more ranks of an affine tensor -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/reduction/thread/reduction_operators.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -/// Parameters structure -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) - typename ElementOutput, ///< Data type of output tensor - typename ElementSource, ///< Data type of source tensor - typename ReductionOp, ///< Reduction operator - int VectorLength = 1, ///< Vector length for memory - typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -struct TensorReductionAffineStridedParams { - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - - Coord extent; /// Extent of source tensor - FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank - int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J - int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K - int64_t workspace_stride; /// stride (units of bytes) between workspace - int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace - int workspace_count; /// number of workspaces - - uint64_t inner_count; /// Number of elements in reduced index space - uint64_t outer_count; /// Number of elements in outer index space - - ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank - ElementSource const * source; /// Pointer to source pointer of rank kRank - ReductionOp reduction_op; /// Reduction operator - ElementCompute reduction_identity; /// Identity element for reduction operator - ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - TensorReductionAffineStridedParams() { - - } - - /// Ctor - TensorReductionAffineStridedParams( - Coord extent_, ///< Extent of source tensor - ElementOutput * dst_ptr_, ///< Output tensor data - int64_t dst_stride_[], ///< Stride (units of elements) - ElementSource const * src_ptr_, ///< Source tensor data - int64_t src_stride_[], ///< Stride (units of elements) - ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions - int64_t workspace_stride_, ///< Stride between workspaces - int workspace_count_, ///< Number of workspaces - ReductionOp reduction_op_, ///< Reduction operator - ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator - ): - extent(extent_), - inner_count(1), - outer_count(1), - destination(dst_ptr_), - source(src_ptr_), - device_workspace(device_workspace_), - workspace_outer_stride(0), - workspace_stride(workspace_stride_), - workspace_count(workspace_count_), - reduction_op(reduction_op_), - reduction_identity(reduction_identity_) { - - // Initialize divisors for fast div-mod - for (int p = 1; p < kRank; ++p) { - divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); - } - - int input_size_bits = sizeof_bits::value; - int output_size_bits = sizeof_bits::value; - - workspace_outer_stride = workspace_stride * workspace_count; - - // Compute strides in units of bytes - for (int p = 0; p < kReducedRank - 1; ++p) { - dst_stride[p] = dst_stride_[p] * output_size_bits / 8; - } - - for (int p = 0; p < kRank - 1; ++p) { - src_stride[p] = src_stride_[p] * input_size_bits / 8; - } - - // Compute number of elements in strided ranks - for (int p = 0; p < kReducedRank - 1; ++p) { - outer_count *= uint64_t(extent[p]); - } - - for (int p = 0; p < kInnerRank; ++p) { - inner_count *= uint64_t(extent[kReducedRank + p - 1]); - } - } -}; - -/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous -/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) - typename ElementOutput, ///< Data type of output tensor - typename ElementSource, ///< Data type of source tensor - typename ReductionOp, ///< Reduction operator - int VectorLength = 1, ///< Vector length for memory - typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -class TensorReductionAffineStrided { -public: - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - using ComputeFragment = Array; - using SourceFragment = AlignedArray; - using OutputFragment = AlignedArray; - - /// Shared memory allocation used for reduction within the CTA - struct SharedStorage { - Array workspace; - }; - - /// Parameters structure - using Params = TensorReductionAffineStridedParams< - Rank, - ReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - VectorLength, - ElementCompute, - Threads, - BatchSize - >; - -private: - - /// Computes the coordinate and offset of a given linear index - CUTLASS_DEVICE - void compute_inner_coord_and_offset_( - Params const ¶ms, - Coord & coord, - int64_t &src_offset, - uint64_t linear_idx) const { - - // Decompose into coordinate - coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank - 1]); - - // Compute linear offset - src_offset = 0; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kInnerRank; ++i) { - src_offset += params.src_stride[kReducedRank + i - 1] * coord[i]; - } - } - - /// Computes the coordinate and offset of a given linear index - CUTLASS_DEVICE - void compute_outer_coord_and_offset_( - Params const ¶ms, - Coord & coord, - int64_t &dst_offset, - int64_t &src_offset, - uint64_t linear_idx) const { - - // Decompose linear coordinate - coord = CoordinateDecomposition(linear_idx, params.divmod); - - // Compute offset into tensors - dst_offset = 0; - src_offset = 0; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kReducedRank - 1; ++i) { - dst_offset += params.dst_stride[i] * coord[i]; - src_offset += params.src_stride[i] * coord[i]; - } - } - - /// Reduces over the reduction indices - CUTLASS_DEVICE - ComputeFragment reduce_indices_( - Params const ¶ms, - ElementCompute *threadblock_workspace, - char const *src_byte_ptr) { - - NumericArrayConverter convert_source; - ReductionOp reduction_op(params.reduction_op); - - // Accumulated output - ComputeFragment identity_frag; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(identity_frag.size()); ++i) { - identity_frag[i] = params.reduction_identity; - } - - if (!params.inner_count) { - return identity_frag; - } - - ComputeFragment accumulator = identity_frag; - - // Compute the coordinate of the first access - int64_t src_byte_offset = 0; - Coord coord; - - uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z; - compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); - - // Load the first vector - SourceFragment source_fragment[kBatchSize]; - - bool not_done = true; - - // Iterate over vectors in a linearized reduction index space - while (not_done) { - - bool guards[kBatchSize]; - - // Issue a batch of loads - CUTLASS_PRAGMA_UNROLL - for (int b = 0; b < kBatchSize; ++b) { - - if (linear_idx < params.inner_count) { - source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); - guards[b] = true; - } - else { - guards[b] = false; - not_done = false; - } - - linear_idx += blockDim.z * gridDim.z; - compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); - } - - // Perform a batch of reduction operations - CUTLASS_PRAGMA_UNROLL - for (int b = 0; b < kBatchSize; ++b) { - if (guards[b]) { - - auto cvt = convert_source(source_fragment[b]); - - accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( - reduction_op, - accumulator, - cvt); - } - } - }; - - // Optional reduction within a CTA - if (blockDim.z > 1) { - - // Linearized thread ID - int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); - - // all threads store to workspace - ComputeFragment *frag_ptr = reinterpret_cast(threadblock_workspace); - - frag_ptr[thread_idx] = accumulator; - - __syncthreads(); - - if (threadIdx.z == 0) { - // Load all additional block indices - for (int z = 1; z < blockDim.z; ++z) { - ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y]; - - accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( - reduction_op, - accumulator, - frag); - } - } - - __syncthreads(); - } - - return accumulator; - } - -public: - - /// Perform a reduction - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; - - char const * src_byte_ptr = reinterpret_cast(params.source + coord_c); - char * dst_byte_ptr = nullptr; - - // If performing a reduction across CTAs, redirect output to device workspace - if (gridDim.z == 1) { - dst_byte_ptr = reinterpret_cast(params.destination + coord_c); - } - else { - dst_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); - } - - // If the C index is out of bounds, exit - if (coord_c >= params.extent[kRank - 1]) { - return; - } - - int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; - - // Use modulo division to compute location - Coord outer_coord; - int64_t dst_byte_offset; - int64_t src_byte_offset; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - src_byte_offset, - idx_linear); - - if (gridDim.z == 1) { - - /// Complete the reduction with no workspace - while (idx_linear < params.outer_count) { - - ComputeFragment result; - - result = reduce_indices_( - params, - shared_storage.workspace.data(), - src_byte_ptr + src_byte_offset); - - // Store the result after possible final reduction within the CTA - if (threadIdx.z == 0) { - - // Convert to output type and store - NumericArrayConverter convert_output; - auto cvt = convert_output(result); - - *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = - reinterpret_cast(cvt); - } - - // Update indices and pointers - idx_linear += gridDim.y * blockDim.y; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - src_byte_offset, - idx_linear); - - } // while - } - else { - - /// Complete the reduction with a device workspace - while (idx_linear < params.outer_count) { - - ComputeFragment result; - - result = reduce_indices_( - params, - shared_storage.workspace.data(), - src_byte_ptr + src_byte_offset); - - // Store the result after possible final reduction within the CTA - if (threadIdx.z == 0) { - - int64_t byte_offset = - blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride; - - // No conversion - store in compute type - *reinterpret_cast(dst_byte_ptr + byte_offset) = - reinterpret_cast(result); - } - - // Update indices and pointers - idx_linear += gridDim.y * blockDim.y; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - src_byte_offset, - idx_linear); - - } // while (outer index) - } // if () - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel to perform final reduction -template < - int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) - int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) - typename ElementOutput, ///< Data type of output tensor - typename ElementSource, ///< Data type of source tensor - typename ReductionOp, ///< Reduction operator - int VectorLength = 1, ///< Vector length for memory - typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation - int Threads = 256, ///< Number of participating threads - int BatchSize = 4 ///< Number of elements to load per batch -> -class TensorReductionAffineStridedFinal { -public: - - static int const kRank = Rank; - static int const kReducedRank = ReducedRank; - static int const kVectorLength = VectorLength; - static int const kInnerRank = kRank - kReducedRank; - static int const kThreads = Threads; - static int const kBatchSize = BatchSize; - using ComputeFragment = Array; - using SourceFragment = AlignedArray; - using OutputFragment = AlignedArray; - - /// Shared memory - struct SharedStorage { }; - - /// Parameters structure - using Params = TensorReductionAffineStridedParams< - Rank, - ReducedRank, - ElementOutput, - ElementSource, - ReductionOp, - VectorLength, - ElementCompute, - Threads, - BatchSize - >; - -private: - - /// Computes the coordinate and offset of a given linear index - CUTLASS_DEVICE - void compute_outer_coord_and_offset_( - Params const ¶ms, - Coord & coord, - int64_t &dst_offset, - uint64_t linear_idx) const { - - // Decompose linear index - coord = CoordinateDecomposition(linear_idx, params.divmod); - - // Compute tensor offset - dst_offset = 0; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kReducedRank - 1; ++i) { - dst_offset += params.dst_stride[i] * coord[i]; - } - } - - /// Reduces over the reduction indices - CUTLASS_DEVICE - ComputeFragment reduce_indices_( - Params const ¶ms, - char *src_byte_ptr) { - - ReductionOp reduction_op(params.reduction_op); - - // Accumulated output - ComputeFragment identity_frag; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(identity_frag.size()); ++i) { - identity_frag[i] = params.reduction_identity; - } - - ComputeFragment accumulator = identity_frag; - ComputeFragment workspace_fragments[kBatchSize]; - - // Partially unrolled loop - for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) { - - // Issue a batch of loads - CUTLASS_PRAGMA_UNROLL - for (int b = 0; b < kBatchSize; ++b) { - if (idx + b < params.workspace_count) { - workspace_fragments[b] = - *reinterpret_cast(src_byte_ptr); - } - else { - workspace_fragments[b] = identity_frag; - } - src_byte_ptr += + params.workspace_stride; - } - - // Perform a reduction - CUTLASS_PRAGMA_UNROLL - for (int b = 0; b < kBatchSize; ++b) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kVectorLength; ++i) { - accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]); - } - } - } - - return accumulator; - } - -public: - - // - // Methods - // - - /// Perform a reduction - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; - - char * src_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); - char * dst_byte_ptr = reinterpret_cast(params.destination + coord_c); - - // If the C index is out of bounds, exit - if (coord_c >= params.extent[kRank - 1]) { - return; - } - - int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; - - // Use modulo division to compute location - Coord outer_coord; - int64_t dst_byte_offset; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - idx_linear); - - /// Complete the reduction - while (idx_linear < params.outer_count) { - - int64_t src_byte_offset = idx_linear * params.workspace_outer_stride; - - ComputeFragment result = reduce_indices_( - params, - src_byte_ptr + src_byte_offset); - - // Convert to output type and store - NumericArrayConverter convert_output; - auto cvt = convert_output(result); - - *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = - reinterpret_cast(cvt); - - // Update indices and pointers - idx_linear += gridDim.y * blockDim.y; - - compute_outer_coord_and_offset_( - params, - outer_coord, - dst_byte_offset, - idx_linear); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace reduction -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h deleted file mode 100644 index cc354df56a0fd83f0315370138fca729a2236d79..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h +++ /dev/null @@ -1,234 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic thread level reduction with specializations for Array. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/functional.h" - -namespace cutlass { -namespace reduction { -namespace thread { - -/// Structure to compute the thread level reduction -template -struct Reduce; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial Specialization of Reduce for "plus" (a functional operator) -template -struct Reduce< plus, T > { - - CUTLASS_HOST_DEVICE - T operator()(T lhs, T const &rhs) const { - plus _op; - return _op(lhs, rhs); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization of Reduce for Array -template -struct Reduce < plus, Array> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &in) const { - - Array result; - Reduce< plus, T > scalar_reduce; - result.clear(); - - CUTLASS_PRAGMA_UNROLL - for (auto i = 0; i < N; ++i) { - result[0] = scalar_reduce(result[0], in[i]); - } - - return result; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specializations of Reduce for Array -template -struct Reduce < plus, Array > { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &input) { - - Array result; - - // If there is only 1 element - there is nothing to reduce - if( N ==1 ){ - - result[0] = input.front(); - - } else { - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) - - __half result_d; - Array const *in_ptr_half = reinterpret_cast const *>(&input); - Array const *in_ptr_half2 = reinterpret_cast const *>(&input); - __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); - - // Set initial result = first half2, in case N==2 - __half2 tmp_result = x_in_half2[0]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < N/2; ++i) { - - tmp_result = __hadd2(x_in_half2[i], tmp_result); - - } - - result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); - - // One final step is needed for odd "N" (to add the (N-1)th element) - if( N%2 ){ - - __half last_element; - Array tmp_last; - Array *tmp_last_ptr = &tmp_last; - tmp_last_ptr[0] = in_ptr_half[N-1]; - last_element = reinterpret_cast<__half const &>(tmp_last); - - result_d = __hadd(result_d, last_element); - - } - - Array *result_ptr = &result; - *result_ptr = reinterpret_cast &>(result_d); - - #else - - Reduce< plus, half_t > scalar_reduce; - result.clear(); - - CUTLASS_PRAGMA_UNROLL - for (auto i = 0; i < N; ++i) { - - result[0] = scalar_reduce(result[0], input[i]); - - } - - #endif - } - - return result; - - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specializations of Reduce for AlignedArray -template -struct Reduce < plus, AlignedArray > { - - CUTLASS_HOST_DEVICE - Array operator()(AlignedArray const &input) { - - Array result; - - // If there is only 1 element - there is nothing to reduce - if( N ==1 ){ - - result[0] = input.front(); - - } else { - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) - - __half result_d; - AlignedArray const *in_ptr_half = reinterpret_cast const *>(&input); - AlignedArray const *in_ptr_half2 = reinterpret_cast const *>(&input); - __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); - - // Set initial result = first half2, in case N==2 - __half2 tmp_result = x_in_half2[0]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < N/2; ++i) { - - tmp_result = __hadd2(x_in_half2[i], tmp_result); - - } - - result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); - - // One final step is needed for odd "N" (to add the (N-1)th element) - if( N%2 ){ - - __half last_element; - AlignedArray tmp_last; - AlignedArray *tmp_last_ptr = &tmp_last; - tmp_last_ptr[0] = in_ptr_half[N-1]; - last_element = reinterpret_cast<__half const &>(tmp_last); - - result_d = __hadd(result_d, last_element); - - } - - Array *result_ptr = &result; - *result_ptr = reinterpret_cast &>(result_d); - - #else - - Reduce< plus, half_t > scalar_reduce; - result.clear(); - - CUTLASS_PRAGMA_UNROLL - for (auto i = 0; i < N; ++i) { - - result[0] = scalar_reduce(result[0], input[i]); - - } - - #endif - } - - return result; - - } -}; -} -} -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h deleted file mode 100644 index 3792d332de65f19a1d30ba311d34073201176a3b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h +++ /dev/null @@ -1,235 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Kernel performing a reduction over densely packed tensors in global memory -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reduction { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mixed-precision reduction -template < - typename ElementAccumulator_, - typename Element_, - int Count = 1 -> -struct ReduceAdd { - - // - // Type definitions - // - - using ElementAccumulator = ElementAccumulator_; - using Element = Element_; - static int const kCount = Count; - - using FragmentAccumulator = cutlass::Array; - using FragmentElement = cutlass::Array; - - struct Params { }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - ReduceAdd(Params params_ = Params()): params(params_) { } - - /// Operator - CUTLASS_HOST_DEVICE - FragmentAccumulator operator()( - FragmentAccumulator accumulator, - FragmentElement element) const { - - plus op; - - NumericArrayConverter< - ElementAccumulator, - Element, - kCount, - PreferredRoundingMode::kRound> converter; - - return op(accumulator, converter(element)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Special handling for binary operators -template -struct VectorizeArrayOperation { - - using ValueType = Array; - - CUTLASS_HOST_DEVICE - ValueType operator()( - ReductionOp const &reduction_op, - ValueType const &lhs, - ValueType const &rhs) const { - - ValueType result; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = reduction_op(lhs[i], rhs[i]); - } - - return result; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct ReduceArrayOperation { - - using ArrayType = Array; - - CUTLASS_HOST_DEVICE - Element operator()( - ReductionOp const &reduction_op, - ArrayType const &array) const { - - Element item = reduction_op(array[0], array[1]); - - CUTLASS_PRAGMA_UNROLL - for (int i = 2; i < N; ++i) { - item = reduction_op(item, array[i]); - } - - return item; - } -}; - -template -struct ReduceArrayOperation, uint1b_t, N> { - - using ArrayType = Array; - - CUTLASS_HOST_DEVICE - uint1b_t operator()( - logical_and const &reduction_op, - ArrayType const &array) const { - - uint8_t const *ptr = reinterpret_cast(&array); - bool item = false; - - CUTLASS_PRAGMA_UNROLL - for (int byte = 0; byte < (N + 7) / 8; ++byte) { - uint8_t bits = ptr[byte]; - item = (item || !bits); - } - - return uint1b_t{!item}; - } -}; - -template -struct ReduceArrayOperation, uint1b_t, N> { - - using ArrayType = Array; - - CUTLASS_HOST_DEVICE - uint1b_t operator()( - logical_and const &reduction_op, - ArrayType const &array) const { - - uint8_t const *ptr = reinterpret_cast(&array); - bool item = true; - - CUTLASS_PRAGMA_UNROLL - for (int byte = 0; byte < (N + 7) / 8; ++byte) { - uint8_t bits = ptr[byte]; - item = (item || bits); - } - - return uint1b_t{item}; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper function to infer template argument types -template -CUTLASS_HOST_DEVICE -Array ApplyArrayOperator( - ReductionOp const &reduction_op, - Array const &lhs, - Array const &rhs) { - - VectorizeArrayOperation vectorize_op; - - return vectorize_op(reduction_op, lhs, rhs); -} - -/// Helper to reduce an array -template -Element ReduceArray(ReductionOp const &reduction_op, Array const &array) { - ReduceArrayOperation reduce_array_op; - - return reduce_array_op(reduction_op, array); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace reduction -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h deleted file mode 100644 index bbabaed2736cac7043671f10e9813a9a48b1916c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h +++ /dev/null @@ -1,67 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -* -**************************************************************************************************/ -/*! \file -\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation. -*/ -#pragma once -#include "cutlass/coord.h" - -namespace cutlass { -namespace reduction { -struct DefaultBlockSwizzle { - /// Ctor - CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} - - /// Swizzle the block index. - CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } - - /// - CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, - Coord<3> const &OutputTile) { - assert(OutputTile[0] == 1 && OutputTile[1] == 1); - assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0); - dim3 grid; - grid.x = problem_size[0] * problem_size[1] * problem_size[2] - / OutputTile[2] ; - return grid; - } - - /// - CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) { - assert(SubTile[0] == 1 && SubTile[1] == 1); - dim3 block = swizzle(); - Coord<3> threadblock_offset = - make_Coord(0, 0, block.x * SubTile[2]); - return threadblock_offset; - } -}; -} // namespace reduction -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h deleted file mode 100644 index 68bdb26e38b1a54843eb4883833ad6b8708f0aff..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h +++ /dev/null @@ -1,305 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Performs comparison between two elements with support for floating-point comparisons. -*/ - -#pragma once - -#include "numeric_types.h" -#include "complex.h" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_HOST_DEVICE -bool relatively_equal(T a, T b, U epsilon, U nonzero_floor); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -// This floating-point comparison function implements the method described in -// -// https://floating-point-gui.de/errors/comparison/ -// -template -CUTLASS_HOST_DEVICE -bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { - -#if defined(__CUDACC_RTC__) - using cuda::std::abs; -#else - using std::abs; -#endif - - T abs_A = abs(a); - T abs_B = abs(b); - T diff = abs(a - b); - T zero = T(0); - - if (a == b) { - return true; - } - else if (a == zero || b == zero || (abs_A + abs_B) < nonzero_floor) { - return diff < epsilon * nonzero_floor; - } - - return diff < epsilon * (abs_A + abs_B); -} - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(bool a, bool b, bool, bool) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(int2b_t a, int2b_t b, int2b_t, int2b_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(int4b_t a, int4b_t b, int4b_t, int4b_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(int8_t a, int8_t b, int8_t, int8_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint8_t a, uint8_t b, uint8_t, uint8_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(int16_t a, int16_t b, int16_t, int16_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint16_t a, uint16_t b, uint16_t, uint16_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(int32_t a, int32_t b, int32_t, int32_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint32_t a, uint32_t b, uint32_t, uint32_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(int64_t a, int64_t b, int64_t, int64_t) { - return (a == b); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(uint64_t a, uint64_t b, uint64_t, uint64_t) { - return (a == b); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_e4m3_t a, float_e4m3_t b, float_e4m3_t epsilon, float_e4m3_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_e5m2_t a, float_e5m2_t b, float_e5m2_t epsilon, float_e5m2_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal( - bfloat16_t a, - bfloat16_t b, - bfloat16_t epsilon, - bfloat16_t nonzero_floor) { - - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal( - tfloat32_t a, - tfloat32_t b, - tfloat32_t epsilon, - tfloat32_t nonzero_floor) { - - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float a, float b, float epsilon, float nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template -CUTLASS_HOST_DEVICE -bool relatively_equal(complex a, complex b, T epsilon, T nonzero_floor) { -#if defined(__CUDACC_RTC__) - using cuda::std::abs; -#else - using std::abs; -#endif - - T abs_A = abs(a); - T abs_B = abs(b); - T diff = abs(a - b); - complex zero = complex{T{}, T{}}; - - if (a == b) { - return true; - } - else if (a == zero || b == zero || diff < nonzero_floor) { - return diff < epsilon * nonzero_floor; - } - - return diff < epsilon * (abs_A + abs_B); -} - -template -CUTLASS_HOST_DEVICE -bool relatively_equal(complex a, complex b, complex epsilon, complex nonzero_floor) { -#if defined(__CUDACC_RTC__) - using cuda::std::abs; -#else - using std::abs; -#endif - - T abs_A = abs(a); - T abs_B = abs(b); - complex diff = a - b; - T abs_diff = abs(diff); - complex zero = complex{T{}, T{}}; - - if (a == b) { - return true; - } - else if (a == zero || b == zero || abs_diff < abs(nonzero_floor)) { - return abs_diff < abs(epsilon * nonzero_floor); - } - - return abs_diff < abs(epsilon) * (abs_A + abs_B); -} - - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_e2m3_t a, float_e2m3_t b, float_e2m3_t epsilon, float_e2m3_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_e3m2_t a, float_e3m2_t b, float_e3m2_t epsilon, float_e3m2_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_e2m1_t a, float_e2m1_t b, float_e2m1_t epsilon, float_e2m1_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_ue8m0_t a, float_ue8m0_t b, float_ue8m0_t epsilon, float_ue8m0_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -template <> -CUTLASS_HOST_DEVICE -bool relatively_equal(float_ue4m3_t a, float_ue4m3_t b, float_ue4m3_t epsilon, float_ue4m3_t nonzero_floor) { - return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/semaphore.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/semaphore.h deleted file mode 100644 index 09a0a1a4572775bbdbdba63a160952e35fef2c20..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/semaphore.h +++ /dev/null @@ -1,118 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/gemm/gemm.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// CTA-wide semaphore for inter-CTA synchronization. -class Semaphore { -public: - - int *lock; - bool wait_thread; - int state; - -public: - - /// Implements a semaphore to wait for a flag to reach a given value - CUTLASS_HOST_DEVICE - Semaphore(int *lock_, int thread_id): - lock(lock_), - wait_thread(thread_id < 0 || thread_id == 0), - state(-1) { - - } - - /// Permit fetching the synchronization mechanism early - CUTLASS_DEVICE - void fetch() { - if (wait_thread) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - #else - asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - #endif - } - } - - /// Gets the internal state - CUTLASS_DEVICE - int get_state() const { - return state; - } - - /// Waits until the semaphore is equal to the given value - CUTLASS_DEVICE - void wait(int status = 0) { - while( __syncthreads_and(state != status) ) { - fetch(); - } - - __syncthreads(); - } - - /// Updates the lock with the given result - CUTLASS_DEVICE - void release(int status = 0) { - __syncthreads(); - - if (wait_thread) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); - #else - asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); - #endif - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h deleted file mode 100644 index 6e98cdc3886b06626ea7d003122d62078f7767b9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h +++ /dev/null @@ -1,1388 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Provides a mechanism for packing and unpacking elements smaller than one byte -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/integer_subbyte.h" -#include "cutlass/fast_math.h" - -namespace cutlass { - -namespace detail { -// This is an implementation detail of cutlass::SubbyteReference and. -// cutlass::HostTensor. For a given logical element type Element, -// and its corresponding storage (physical) element type StorageUnit, -// it computes quantities that help with managing allocations. -// -// CUTLASS uses a hidden "ContainerUnitType" or StorageUnit type to support -// packed arrays of subbyte types such as int4. Element is the "logical" type -// for computations, while CUTLASS uses StorageUnit as the element type -// of a packed array of Element. If Element is not a subbyte type, -// then the corresponding StorageUnit type is just Element itself. -// -// The ContainerType is always calculated as an array StorageUnit type (the StorageUnit -// is always a byte for subbyte types), -// and its number of bits is the lcm of the subbyte type's number of bits and 8. -// Below are some examples for different subbyte types. -// -// * Subbyte Type=int2, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) -// * Subbyte Type=int4, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) -template -struct StorageContainerCalculator { - // kContainerTypeNumBits: The number of bits needed for ContainerType - static constexpr int kContainerTypeNumBits = (sizeof_bits::value < 8) ? cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value) : sizeof_bits::value; - static_assert(kContainerTypeNumBits % sizeof_bits::value == 0, "The bits of ContainerType should be divisible by the element's number of bits"); - // kContainerTypeNumLogicalElements: The number of logical Element instance(s) that can be stored per ContainerType instance - static constexpr int kContainerTypeNumLogicalElements = kContainerTypeNumBits / sizeof_bits::value; - /// 3. kContainerTypeNumBytes: The number of bytes per ContainerType instance - static constexpr int kContainerTypeNumBytes = kContainerTypeNumBits / 8; - /// 4. kContainerTypeNumBytes: The number of base StorageUnit in the ContainerType - static constexpr int kContainerTypeNumStorageUnit = kContainerTypeNumBits / sizeof_bits::value; - - static_assert(kContainerTypeNumBits != 0, "kContainerTypeNumBits can not be zero"); - static_assert(kContainerTypeNumLogicalElements != 0, "kContainerTypeNumLogicalElements can not be zero"); - static_assert(kContainerTypeNumBytes != 0, "kContainerTypeNumBytes can not be zero"); -}; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This class provides a mechanism for packing and unpacking elements smaller than one byte. It -/// assumes these sub-byte elements are packed in a traditional C++ numeric type. -/// -/// The intended application is to provide a mechanism to indirectly reference elements in -/// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller -/// than one byte. -/// -/// Supports basic pointer arithmetic: -/// -/// Example: -/// -/// int4b_t *ptr = ...; -/// -/// SubbyteReference ref = ptr; -/// ref += 15; -/// -/// int4b_t x = ref; // load an int4b_t -/// ref = x + 2_s4; // perform arithmetic on int4b_t and then store -/// -template < - typename Element_, /// CUTLASS numeric element type. - typename Storage_ = uint8_t, /// Underlying storage type. Must be able to hold an integer - /// number of objects of type Element. - class = void -> -class ConstSubbyteReference { -public: - - using Element = Element_; - using Storage = Storage_; - using StoragePointer = Storage const *; - - static_assert(sizeof_bits::value <= sizeof_bits::value, - "Size of Element must not be greater than Storage."); - - static_assert(!(sizeof_bits::value % sizeof_bits::value), - "Storage must be divisible by Element"); - -private: - - ///! Number of elements per storage vector - int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; - - ///! Bit mask - Storage const kMask = - ((sizeof_bits::value < sizeof_bits::value) ? - (Storage(1) << sizeof_bits::value) - Storage(1) : - ~Storage(0)); - -private: - - /// Pointer to array containing element - StoragePointer ptr_; - - /// Offset (in units of elements) from pointer. - /// - /// Invariant: must always be in range [0, kElementsPerVector) - int offset_; - -public: - - CUTLASS_HOST_DEVICE - ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } - - /// Constructor - CUTLASS_HOST_DEVICE - ConstSubbyteReference( - Element const *ptr, /// pointer to memory - int64_t offset /// logical offset in units of Element - ): - ptr_(reinterpret_cast(ptr)), - offset_(0) { - - int64_t offset_in_vectors = offset / kElementsPerVector; - int64_t offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = int(offset_in_elements); - } - - /// Constructor - CUTLASS_HOST_DEVICE - ConstSubbyteReference( - Element *ptr = nullptr - ): ConstSubbyteReference(ptr, 0) { } - - /// Gets storage pointer - CUTLASS_HOST_DEVICE - StoragePointer storage_pointer() const { - return ptr_; - } - - /// Gets element offset within storage vector - CUTLASS_HOST_DEVICE - int element_offset() const { - return offset_; - } - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - Element get() const { - Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); - return reinterpret_cast(item); - } - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - operator Element() const { - return get(); - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator+=(int offset) { - - offset += offset_; - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator+=(long long offset) { - - offset += offset_; - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator-=(int offset) { - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator-=(long long offset) { - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - return *this; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator+(int offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator+(long long offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator-(int offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator-=(long long offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Computes the difference in elements between references - CUTLASS_HOST_DEVICE - ptrdiff_t operator-(ConstSubbyteReference ref) const { - return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); - } - - /// Explicit cast to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to signed 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator int64_t() const { - return int64_t(get()); - } - - /// Explicit cast to unsigned 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator uint64_t() const { - return uint64_t(get()); - } - - /// Explicit cast to float - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - - /// Explicit cast to double - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(get()); - } -}; - -template < - typename Element_, /// CUTLASS numeric element type. - typename Storage_ = /// Underlying storage type. Must be able to hold an integer - /// number of objects of type Element. - -#if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads. - #if (__CUDA_ARCH__ >= 700) /// - uint16_t - #else - uint32_t - #endif -#else - uint8_t -#endif - , - class = void -> -class SubbyteReference { -public: - - using Element = Element_; - using Storage = Storage_; - using StoragePointer = Storage *; - - static_assert(sizeof_bits::value <= sizeof_bits::value, - "Size of Element must not be greater than Storage."); - - static_assert(!(sizeof_bits::value % sizeof_bits::value), - "Storage must be divisible by Element"); - -private: - - ///! Number of elements per storage vector - int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; - - ///! Bit mask - Storage const kMask = - ((sizeof_bits::value < sizeof_bits::value) ? - (Storage(1) << sizeof_bits::value) - Storage(1) : - ~Storage(0)); - -private: - - /// Pointer to array containing element - StoragePointer ptr_; - - /// Offset (in units of elements) from pointer. - /// - /// Invariant: must always be in range [0, kElementsPerVector) - int offset_; - -public: - - CUTLASS_HOST_DEVICE - SubbyteReference(): ptr_(nullptr), offset_(0) { } - - /// Constructor - CUTLASS_HOST_DEVICE - SubbyteReference( - Element *ptr, /// pointer to memory - int64_t offset /// logical offset in units of Element - ): - ptr_(reinterpret_cast(ptr)), - offset_(0) { - - int64_t offset_in_vectors = offset / kElementsPerVector; - int64_t offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = int(offset_in_elements); - } - - /// Constructor - CUTLASS_HOST_DEVICE - SubbyteReference( - Element *ptr = nullptr - ): SubbyteReference(ptr, 0) { } - - /// Gets storage pointer - CUTLASS_HOST_DEVICE - StoragePointer storage_pointer() const { - return ptr_; - } - - /// Gets storage pointer - CUTLASS_HOST_DEVICE - Element * operator&() const { - return reinterpret_cast(ptr_); - } - - /// Gets element offset within storage vector - CUTLASS_HOST_DEVICE - int element_offset() const { - return offset_; - } - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - Element get() const { - uint8_t const* byte_ptr = reinterpret_cast(ptr_); - // Convert offset in elements to offset in bytes - constexpr int elements_per_byte = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; - byte_ptr += offset_ / elements_per_byte; - // Offset of element within a byte - int byte_offset = offset_ % elements_per_byte; - uint8_t item = uint8_t((*byte_ptr >> (byte_offset * cutlass::sizeof_bits::value)) & kMask); - return reinterpret_cast(item); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference & set(Element const &x) { - - Storage item = (reinterpret_cast(x) & kMask); - Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); - Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits::value)); - -#if defined(__CUDA_ARCH__) - - // - // Homebrew read-modify-write - // - Storage original; - Storage updated; - - do { - - original = (*ptr_); - - updated = Storage((original & kUpdateMask) | new_bits); - - original = atomicCAS(ptr_, original, updated); - - } while (updated != original); - -#else - - Storage original = (*ptr_); - Storage updated = Storage((original & kUpdateMask) | new_bits); - *ptr_ = updated; - -#endif - - return *this; - } - - //// - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - operator Element() const { - return get(); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference &operator=(Element const & x) { - return set(x); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference &operator=(SubbyteReference const & x) { - return set(x.get()); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference &operator=( - ConstSubbyteReference const &x) { - return set(x.get()); - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator+=(int offset) { - - offset += offset_; - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator+=(long long offset) { - - offset += offset_; - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator-=(int offset) { - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator-=(long long offset) { - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - return *this; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator+(int offset) const { - - SubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator+(long long offset) const { - - SubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator-(int offset) const { - - SubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator-=(long long offset) const { - - SubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Computes the difference in elements between references - CUTLASS_HOST_DEVICE - ptrdiff_t operator-(SubbyteReference ref) const { - return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); - } - - /// Explicit cast to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to signed 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator int64_t() const { - return int64_t(get()); - } - - /// Explicit cast to unsigned 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator uint64_t() const { - return uint64_t(get()); - } - - /// Explicit cast to float - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - - /// Explicit cast to double - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(get()); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template using _war = T; -template < - typename Element_, /// CUTLASS numeric element type. - typename Storage_ /// Underlying basic storage type. -> -class SubbyteReference::value % sizeof_bits::value != 0>::type> { -public: - - using Element = Element_; - /// Note: It's possible that StorageUnit is not divisible by Element. - /// For example, an Element instance might be stored across 2 StorageUnit instances. - /// Thus, CUTLASS needs a storage vector to hold an integer number of Element instances. - - using StorageUnit = Storage_; -private: - using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; -public: - static int const kBitsStoredVec = StorageContainerCalculator::kContainerTypeNumBits; - static int const kNumStorageUnitPerStoredVec = StorageContainerCalculator::kContainerTypeNumStorageUnit; - - using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; - using StorageVecPointer = StorageVec *; - - using CudaAtomicType = typename platform::conditional< - sizeof_bits::value == 16, - uint32_t, - uint64_t - >::type; - - static_assert(sizeof_bits::value <= sizeof_bits::value, - "Size of Element must not be greater than StorageVec."); - - static_assert(!(sizeof_bits::value % sizeof_bits::value), - "StorageVec must be divisible by Element"); - -private: - - ///! Number of elements per storage vector - int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; - - ///! Bit mask for storage unit. - StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); - - /// Pointer to array containing element - _war ptr_; - - /// Offset (in units of elements) from pointer. - /// - /// Invariant: must always be in range [0, kElementsPerVector) - int offset_; - - /// Element may be stored across 2 storage unit. - /// Low storage unit index in StorageVec - /// High storage unit index in StorageVec - int low_storage_unit_idx_; - int high_storage_unit_idx_; - - /// Full Mask to extract the entire element - uint64_t full_element_mask_; - - /// Mask to extract the Element from Low storage unit and High storage unit. - StorageUnit low_storage_mask_; - StorageUnit high_storage_mask_; - - /// Start bit index inside the storage unit. - int start_bit_idx_; - -private: - - CUTLASS_HOST_DEVICE - void update_element_status() { - int num_bits = offset_ * sizeof_bits::value; - - start_bit_idx_ = num_bits % sizeof_bits::value; - - low_storage_unit_idx_ = num_bits / sizeof_bits::value; - high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value - ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_; - - full_element_mask_ = uint64_t(kMask) << start_bit_idx_; - low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); - high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); - } - -public: - - CUTLASS_HOST_DEVICE - SubbyteReference(): ptr_(nullptr), offset_(0) { } - - /// Constructor - CUTLASS_HOST_DEVICE - SubbyteReference( - Element *ptr, /// pointer to memory - int64_t offset /// logical offset in units of Element - ): - ptr_(reinterpret_cast(ptr)), - offset_(0) { - int64_t offset_in_vectors = offset / kElementsPerVector; - int64_t offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = int(offset_in_elements); - - update_element_status(); - } - - /// Constructor - CUTLASS_HOST_DEVICE - SubbyteReference( - Element *ptr = nullptr - ): SubbyteReference(ptr, 0) { } - - /// Gets StorageVec pointer - CUTLASS_HOST_DEVICE - StorageVecPointer storage_pointer() const { - return ptr_; - } - - /// Gets StorageVec pointer - CUTLASS_HOST_DEVICE - Element * operator&() const { - return reinterpret_cast(ptr_); - } - - /// Gets element offset within StorageVec vector - CUTLASS_HOST_DEVICE - int element_offset() const { - return offset_; - } - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - Element get() const { - StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; - StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; - - uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; - uint8_t result = uint8_t(full_item >> start_bit_idx_); - - return reinterpret_cast(result); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference & set(Element const &x) { - - uint64_t item = static_cast((reinterpret_cast(x) & kMask)) << start_bit_idx_; - - StorageUnit low_new_bits = StorageUnit(item & ~StorageUnit(0)); - StorageUnit high_new_bits = StorageUnit(item >> sizeof_bits::value); - - StorageUnit const kLowUpdateMask = StorageUnit((~full_element_mask_) & (~StorageUnit(0))); - StorageUnit const kHighUpdateMask = StorageUnit(((~full_element_mask_) >> sizeof_bits::value) & (~StorageUnit(0))); - -#if defined(__CUDA_ARCH__) - // - // Homebrew read-modify-write - // - if(high_storage_unit_idx_ != low_storage_unit_idx_){ - /// Only need update 2 storage unit at once. - /// consider misaligned address issue, we need to do atomicCAS twice - StorageUnit original_low_bits, original_high_bits, update_low_bits, update_high_bits; - do { - original_low_bits = ((*ptr_)[low_storage_unit_idx_]); - update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits; - original_low_bits = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original_low_bits, update_low_bits); - } while (update_low_bits != original_low_bits); - do { - original_high_bits = ((*ptr_)[high_storage_unit_idx_]); - update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits; - original_high_bits = atomicCAS(&((*ptr_)[high_storage_unit_idx_]), original_high_bits, update_high_bits); - } while (update_high_bits != original_high_bits); - } - else { - /// Only need update 1 storage unit. - StorageUnit original, updated; - do { - original = ((*ptr_)[low_storage_unit_idx_]); - - updated = (original & kLowUpdateMask) | low_new_bits; - - original = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original, updated); - - } while (updated != original); - } -#else - - - StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits; - StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits; - - (*ptr_)[low_storage_unit_idx_] = update_low_bits; - - if(low_storage_unit_idx_ != high_storage_unit_idx_) - (*ptr_)[high_storage_unit_idx_] = update_high_bits; -#endif - - return *this; - } - - //// - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - operator Element() const { - return get(); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference &operator=(Element const & x) { - return set(x); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference &operator=(SubbyteReference const & x) { - return set(x.get()); - } - - /// Stores an element to memory - CUTLASS_HOST_DEVICE - SubbyteReference &operator=( - ConstSubbyteReference const &x) { - return set(x.get()); - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator+=(int offset) { - - offset += offset_; - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - update_element_status(); - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator+=(long long offset) { - - offset += offset_; - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - update_element_status(); - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator-=(int offset) { - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - update_element_status(); - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - SubbyteReference &operator-=(long long offset) { - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - update_element_status(); - return *this; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator+(int offset) const { - - SubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator+(long long offset) const { - - SubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator-(int offset) const { - - SubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - SubbyteReference operator-=(long long offset) const { - - SubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Computes the difference in elements between references - CUTLASS_HOST_DEVICE - ptrdiff_t operator-(SubbyteReference ref) const { - return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); - } - - /// Explicit cast to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to signed 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator int64_t() const { - return int64_t(get()); - } - - /// Explicit cast to unsigned 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator uint64_t() const { - return uint64_t(get()); - } - - /// Explicit cast to float - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - - /// Explicit cast to double - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(get()); - } -}; - -template using _war = T; -template < - typename Element_, /// CUTLASS numeric element type. - typename Storage_ /// Underlying storage type. Must be able to hold an integer -> -class ConstSubbyteReference::value % sizeof_bits::value != 0>::type> { -public: - - using Element = Element_; - ///! Note: Storage unit could not be divisibale by Element, - /// Type element may be stored across 2 storage units, so need a storage vector to hold integer - /// number of objects of type Element. - using StorageUnit = Storage_; - static int const kBitsStoredVec = cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value); - static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits::value; - - using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; - using StorageVecPointer = StorageVec const *; - - using CudaAtomicType = typename platform::conditional< - sizeof_bits::value == 16, - uint32_t, - uint64_t - >::type; - - static_assert(sizeof_bits::value <= sizeof_bits::value, - "Size of Element must not be greater than StorageVec."); - - static_assert(!(sizeof_bits::value % sizeof_bits::value), - "StorageVec must be divisible by Element"); - -private: - - ///! Number of elements per storage vector - int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; - - ///! Bit mask for storage unit. - StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); - - /// Pointer to array containing element - _war ptr_; - - /// Offset (in units of elements) from pointer. - /// - /// Invariant: must always be in range [0, kElementsPerVector) - int offset_; - - /// Element may be stored across 2 storage unit. - /// Low storage unit index in StorageVec - /// High storage unit index in StorageVec - int low_storage_unit_idx_; - int high_storage_unit_idx_; - - /// Full Mask to extract the entire element - uint64_t full_element_mask_; - - /// Mask to extract the Element from Low storage unit and High storage unit. - StorageUnit low_storage_mask_; - StorageUnit high_storage_mask_; - - /// Start bit index inside the storage unit. - int start_bit_idx_; - -private: - - CUTLASS_HOST_DEVICE - void update_element_status() { - int num_bits = offset_ * sizeof_bits::value; - - start_bit_idx_ = num_bits % sizeof_bits::value; - - low_storage_unit_idx_ = num_bits / sizeof_bits::value; - high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value - ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_; - - full_element_mask_ = uint64_t(kMask) << start_bit_idx_; - low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); - high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); - } - -public: - - CUTLASS_HOST_DEVICE - ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } - - /// Constructor - CUTLASS_HOST_DEVICE - ConstSubbyteReference( - Element const *ptr, /// pointer to memory - int64_t offset /// logical offset in units of Element - ): - ptr_(reinterpret_cast(ptr)), - offset_(0) { - - int64_t offset_in_vectors = offset / kElementsPerVector; - int64_t offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = int(offset_in_elements); - - update_element_status(); - } - - /// Constructor - CUTLASS_HOST_DEVICE - ConstSubbyteReference( - Element *ptr = nullptr - ): ConstSubbyteReference(ptr, 0) { } - - /// Gets storage pointer - CUTLASS_HOST_DEVICE - StorageVecPointer storage_pointer() const { - return ptr_; - } - - /// Gets element offset within storage vector - CUTLASS_HOST_DEVICE - int element_offset() const { - return offset_; - } - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - Element get() const { - StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; - StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; - - uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; - uint8_t result = uint8_t(full_item >> start_bit_idx_); - - return reinterpret_cast(result); - } - - /// Unpacks an element from memory - CUTLASS_HOST_DEVICE - operator Element() const { - return get(); - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator+=(int offset) { - - offset += offset_; - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - update_element_status(); - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator+=(long long offset) { - - offset += offset_; - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ += offset_in_vectors; - offset_ = offset_in_elements; - - update_element_status(); - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator-=(int offset) { - - int offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = offset % kElementsPerVector; - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - update_element_status(); - - return *this; - } - - /// Adds an offset in units of elements to the reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference &operator-=(long long offset) { - - long long offset_in_vectors = offset / kElementsPerVector; - int offset_in_elements = int(offset % kElementsPerVector); - - ptr_ -= offset_in_vectors; - offset_ -= offset_in_elements; - - if (offset_ < 0) { - offset_ += kElementsPerVector; - --ptr_; - } - - update_element_status(); - - return *this; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator+(int offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator+(long long offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref += offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator-(int offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Returns a reference to an element with a given offset from the current reference - CUTLASS_HOST_DEVICE - ConstSubbyteReference operator-=(long long offset) const { - - ConstSubbyteReference ref(ptr_, offset_); - ref -= offset; - - return ref; - } - - /// Computes the difference in elements between references - CUTLASS_HOST_DEVICE - ptrdiff_t operator-(ConstSubbyteReference ref) const { - return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); - } - - /// Explicit cast to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to signed 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator int64_t() const { - return int64_t(get()); - } - - /// Explicit cast to unsigned 64-bit integer - CUTLASS_HOST_DEVICE - explicit operator uint64_t() const { - return uint64_t(get()); - } - - /// Explicit cast to float - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - - /// Explicit cast to double - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(get()); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template ::value < 8)> -struct ReferenceFactory; - -template -struct ReferenceFactory { - - ///! Number of elements per storage vector - static int const kElementsPerVector = 1; - - CUTLASS_HOST_DEVICE - static Element &get(Element *ptr, int64_t offset) { - return ptr[offset]; - } - - CUTLASS_HOST_DEVICE - static Element const &get(Element const *ptr, int64_t offset) { - return ptr[offset]; - } - - CUTLASS_HOST_DEVICE - static Element *add_pointer_offset(Element *ptr, int64_t offset) { - return ptr + offset; - } - - CUTLASS_HOST_DEVICE - static Element const *add_pointer_offset(Element const *ptr, int64_t offset) { - return ptr + offset; - } -}; - -template -struct ReferenceFactory { - - // - // Static methods - // - - CUTLASS_HOST_DEVICE - static SubbyteReference get(Element *ptr, int64_t offset) { - return SubbyteReference(ptr, offset); - } - - CUTLASS_HOST_DEVICE - static ConstSubbyteReference get(Element const *ptr, - int64_t offset) { - return ConstSubbyteReference(ptr, offset); - } - - /// Helper to add an offset in number of elements, assuming this offset is divisible - /// by the vector size. - CUTLASS_HOST_DEVICE - static Element *add_pointer_offset(Element *ptr, int64_t offset_in_elements) { - return &SubbyteReference(ptr, offset_in_elements); - } - - /// Helper to add an offset in number of elements, assuming this offset is divisible - /// by the vector size. - CUTLASS_HOST_DEVICE - static Element const *add_pointer_offset(Element const *ptr, int64_t offset_in_elements) { - return &ConstSubbyteReference(ptr, offset_in_elements); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h deleted file mode 100644 index a124d395cf2222331e0ceb160271b1621688fd6f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h +++ /dev/null @@ -1,326 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a canonical coordinate for rank=4 tensors offering named indices. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a canonical 4D coordinate used by tensor operations. -struct Tensor4DCoord : public Coord<4> { - - /// Base class - using Base = Coord<4>; - - /// Index type - using Index = typename Base::Index; - - /// LongIndex type - using LongIndex = typename Base::LongIndex; - - /// Batch dimension - static int const kN = 0; - - /// Height dimension - static int const kH = 1; - - /// Width dimension - static int const kW = 2; - - /// Channels dimension - static int const kC = 3; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Tensor4DCoord() { } - - /// Constructs from Coord<4> - CUTLASS_HOST_DEVICE - Tensor4DCoord(Coord<4> const &coord): Base(coord) { } - - /// Helper to construct from N, H, W, and C. - CUTLASS_HOST_DEVICE - Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { } - - /// Helper to construct from N, H, W, and C, which are LongIndex type - CUTLASS_HOST_DEVICE - Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c) - : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { } - - /// Returns the batch of the coordinate - CUTLASS_HOST_DEVICE - Index const & n() const { return this->at(kN); } - - /// Returns the batch of the coordinate - CUTLASS_HOST_DEVICE - Index & n() { return this->at(kN); } - - /// Returns the row of the coordinate - CUTLASS_HOST_DEVICE - Index const & h() const { return this->at(kH); } - - /// Returns the row of the coordinate - CUTLASS_HOST_DEVICE - Index & h() { return this->at(kH); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index const & w() const { return this->at(kW); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index & w() { return this->at(kW); } - - /// Returns the channel of the coordinate - CUTLASS_HOST_DEVICE - Index const & c() const { return this->at(kC); } - - /// Returns the channel of the coordinate - CUTLASS_HOST_DEVICE - Index & c() { return this->at(kC); } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - Tensor4DCoord operator+(Base const& b) const { - return Tensor4DCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - Tensor4DCoord operator-(Base const& b) const { - return Tensor4DCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - Tensor4DCoord operator*(Base const& b) const { - return Tensor4DCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - Tensor4DCoord operator/(Base const& b) const { - return Tensor4DCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - Tensor4DCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - Tensor4DCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - Tensor4DCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - Tensor4DCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a canonical 5D coordinate used by tensor operations. -struct Tensor5DCoord : public Coord<5> { - - /// Base class - using Base = Coord<5>; - - /// Index type - using Index = typename Base::Index; - - /// LongIndex type - using LongIndex = typename Base::LongIndex; - - /// Batch dimension - static int const kN = 0; - - /// Depth dimension - static int const kD = 1; - - /// Height dimension - static int const kH = 2; - - /// Width dimension - static int const kW = 3; - - /// Channels dimension - static int const kC = 4; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Tensor5DCoord() { } - - /// Constructs from Coord<5> - CUTLASS_HOST_DEVICE - Tensor5DCoord(Coord<5> const &coord): Base(coord) { } - - /// Helper to construct from N, D, H, W, and C. - CUTLASS_HOST_DEVICE - Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } - - /// Helper to construct from N, D, H, W, and C, which are LongIndex type - CUTLASS_HOST_DEVICE - Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c) - : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { } - - /// Returns the batch of the coordinate - CUTLASS_HOST_DEVICE - Index const & n() const { return this->at(kN); } - - /// Returns the batch of the coordinate - CUTLASS_HOST_DEVICE - Index & n() { return this->at(kN); } - - /// Returns the batch of the coordinate - CUTLASS_HOST_DEVICE - Index const & d() const { return this->at(kD); } - - /// Returns the batch of the coordinate - CUTLASS_HOST_DEVICE - Index & d() { return this->at(kD); } - - /// Returns the row of the coordinate - CUTLASS_HOST_DEVICE - Index const & h() const { return this->at(kH); } - - /// Returns the row of the coordinate - CUTLASS_HOST_DEVICE - Index & h() { return this->at(kH); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index const & w() const { return this->at(kW); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index & w() { return this->at(kW); } - - /// Returns the channel of the coordinate - CUTLASS_HOST_DEVICE - Index const & c() const { return this->at(kC); } - - /// Returns the channel of the coordinate - CUTLASS_HOST_DEVICE - Index & c() { return this->at(kC); } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - Tensor5DCoord operator+(Base const& b) const { - return Tensor5DCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - Tensor5DCoord operator-(Base const& b) const { - return Tensor5DCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - Tensor5DCoord operator*(Base const& b) const { - return Tensor5DCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - Tensor5DCoord operator/(Base const& b) const { - return Tensor5DCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - Tensor5DCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - Tensor5DCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - Tensor5DCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - Tensor5DCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h deleted file mode 100644 index fc467499996a00645b0a936efe741ece2092fb90..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h +++ /dev/null @@ -1,419 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a structure containing strides, bounds, and a pointer to tensor data. -*/ -#pragma once - - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" -#include "cutlass/platform/platform.h" -#include "cutlass/subbyte_reference.h" - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Default layout function from coordinates in a tensor's index space into the n-D array held -/// in memory. -/// -/// All layout functions must define at least the members shown in IdentityTensorLayout<>. -template -class IdentityTensorLayout { -public: - /// Logical rank of tensor - static int const kRank = Rank; - - /// Rank of stride vector - static int const kStrideRank = Rank; - - /// Index type used for coordinates - using Index = int32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using TensorCoord = Coord; - - /// Stride vector - using Stride = Coord; - -private: - - // - // Data members - // - - /// Stride data member - Stride stride_; - -public: - - // - // Methods - // - - CUTLASS_HOST_DEVICE - IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { } - - /// Returns the offset of a coordinate in linear memory - CUTLASS_HOST_DEVICE - LongIndex operator()(Coord const &coord) const { - return coord.dot(stride_); - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride stride() const { - return stride_; - } - - /// Returns the stride of the layout - CUTLASS_HOST_DEVICE - Stride & stride() { - return stride_; - } - - /// Compute the number of contiguous elements needed to store a tensor with the given size - CUTLASS_HOST_DEVICE - LongIndex capacity(TensorCoord const &size) const { - int idx = stride_.max_dim_index(); - return stride_[idx] * size[idx]; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank - and layout within memory. A TensorRef combines a pointer and a Layout concept - - Examples: - - (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h) - - 1. Column-major matrix may be represented as a rank=2 tensor: - - TensorRef A(ptr_A, ldm); - - 2. Row-major matrix may be represented as a rank=2 tensor: - - TensorRef B(ptr_A, ldm); - - 3. An interleaved matrix may be represented as a rank=2 tensor: - - TensorRef > C; - - 4. A helper exists to define a TensorRef for a contiguous matrix whose layout - is not known at compile time. - - int ldm; // leading dimension - layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor - - - TensorRef E(ptr_E, {ldm, kind}); - -*/ -template < - /// Data type of element stored within tensor (concept: NumericType) - typename Element_, - /// Defines a mapping from logical coordinate to linear memory (concept: Layout) - typename Layout_ -> -class TensorRef { - public: - /// Data type of individual access - using Element = Element_; - - /// Mapping function from logical coordinate to linear memory - using Layout = Layout_; - - /// Reference type to an element - using Reference = typename platform::conditional< - sizeof_bits::value >= 8, - Element &, - SubbyteReference - >::type; - - /// Logical rank of tensor index space - static int const kRank = Layout::kRank; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Layout's stride vector - using Stride = typename Layout::Stride; - - /// TensorRef to constant data - using ConstTensorRef = TensorRef< - typename platform::remove_const::type const, - Layout>; - - /// TensorRef to non-constant data - using NonConstTensorRef = TensorRef< - typename platform::remove_const::type, - Layout>; - - /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a - /// scalar, but degenerate cases such as these are difficult to accommodate without - /// extensive C++ metaprogramming or support for zero-length arrays. - static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); - - private: - - /// Pointer - Element* ptr_; - - /// Layout object maps logical coordinates to linear offsets - Layout layout_; - - public: - - // - // Methods - // - - /// Constructs a TensorRef with a pointer and layout object. - CUTLASS_HOST_DEVICE - TensorRef(): ptr_(nullptr) { - - } - - /// Constructs a TensorRef with a pointer and layout object. - CUTLASS_HOST_DEVICE - TensorRef( - Element *ptr, ///< pointer to start of tensor - Layout const &layout ///< layout object containing stride and mapping function - ): - ptr_(ptr), layout_(layout) { - - } - - /// Converting constructor from TensorRef to non-constant data. - template - CUTLASS_HOST_DEVICE - TensorRef( - NonConstTensorRef const &ref, ///< TensorRef to non-const data - ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const - _Magic magic = (typename platform::enable_if< ! platform::is_same >::value, _Magic>::type)0 - ): - ptr_(ref.data()), layout_(ref.layout()) { } - - /// Returns a reference to constant-valued tensor. - CUTLASS_HOST_DEVICE - ConstTensorRef const_ref() const { - return ConstTensorRef(ptr_, layout_); - } - - CUTLASS_HOST_DEVICE - NonConstTensorRef non_const_ref() const { - return NonConstTensorRef(const_cast::type *>(ptr_), layout_); - } - - /// Updates only the pointer - CUTLASS_HOST_DEVICE - void reset(Element* ptr = nullptr) { - ptr_ = ptr; - } - - /// Updates the pointer and layout object - CUTLASS_HOST_DEVICE - void reset(Element* ptr, Layout const &layout) { - ptr_ = ptr; - layout_ = layout; - } - - /// Returns true if the TensorRef is non-null - CUTLASS_HOST_DEVICE - bool good() const { - return ptr_ != nullptr; - } - - /// Returns the pointer to referenced data - CUTLASS_HOST_DEVICE - Element * data() const { return ptr_; } - - /// Returns a reference to the element at a given linear index - CUTLASS_HOST_DEVICE - Reference data(LongIndex idx) const { - return ReferenceFactory::type, - (sizeof_bits::value < 8)>::get(ptr_, idx); - } - - /// Returns the layout object - CUTLASS_HOST_DEVICE - Layout & layout() { - return layout_; - } - - /// Returns the layout object - CUTLASS_HOST_DEVICE - Layout layout() const { - return layout_; - } - - /// Returns the layout object's stride vector - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the layout object's stride vector - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Returns the layout object's stride in a given physical dimension - CUTLASS_HOST_DEVICE - typename Layout::Stride::Index stride(int dim) const { - return layout_.stride().at(dim); - } - - /// Returns the layout object's stride in a given physical dimension - CUTLASS_HOST_DEVICE - typename Layout::Stride::Index & stride(int dim) { - return layout_.stride().at(dim); - } - - /// Computes the offset of an index from the origin of the tensor - CUTLASS_HOST_DEVICE - LongIndex offset(TensorCoord const& coord) const { - return layout_(coord); - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference at(TensorCoord const& coord) const { - return data(offset(coord)); - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference operator[](TensorCoord const& coord) const { - return data(offset(coord)); - } - - /// Adds an offset to each pointer - CUTLASS_HOST_DEVICE - TensorRef & add_pointer_offset(LongIndex offset_) { - ptr_ = ReferenceFactory::type, - (sizeof_bits::value < 8)>::add_pointer_offset(ptr_, offset_); - return *this; - } - - /// Adds an offset to each pointer - CUTLASS_HOST_DEVICE - TensorRef & add_coord_offset(TensorCoord const &coord) { - add_pointer_offset(offset(coord)); - return *this; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRef operator+(TensorCoord const& b) const { - TensorRef result(*this); - result.add_coord_offset(b); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRef & operator+=(TensorCoord const& b) { - add_coord_offset(b); - return *this; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRef operator-(TensorCoord const& b) const { - TensorRef result(*this); - result.add_pointer_offset(-offset(b)); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRef & operator-=(TensorCoord const& b) { - add_pointer_offset(-offset(b)); - return *this; - } -}; - -/// Constructs a TensorRef, deducing types from arguments. -template < - typename Element, - typename Layout -> -CUTLASS_HOST_DEVICE -TensorRef make_TensorRef(Element *ptr, Layout const &layout) { - return TensorRef(ptr, layout); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations to handle degenerate and sub-byte cases. -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Element, - typename Layout -> -CUTLASS_HOST_DEVICE -bool TensorRef_aligned(TensorRef const &ref, int alignment) { - - int const kStrideRank = Layout::kStrideRank; - - if (reinterpret_cast(ref.data()) % alignment) { - return false; - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kStrideRank; ++i) { - if (ref.stride(i) % alignment) { - return false; - } - } - - return true; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h deleted file mode 100644 index 9ba3a2308081e8c4b11d18cb8125ec7943e534f0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h +++ /dev/null @@ -1,374 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a structure containing strides, bounds, and a pointer to tensor data. -*/ -#pragma once - -#include -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/tensor_ref.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct PlanarComplexReference { - - // - // Type definitions - // - - using Element = Element_; - using ComplexElement = complex; - - // - // Data members - // - - Element *real; - Element *imag; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - PlanarComplexReference( - Element *real_ = nullptr, - Element *imag_ = nullptr - ): - real(real_), imag(imag_) { } - - /// Loads the complex element - CUTLASS_HOST_DEVICE - operator complex() const { - return complex{*real, *imag}; - } - - /// Stores a complex element to the location pointed to by the reference - CUTLASS_HOST_DEVICE - PlanarComplexReference &operator=(complex const &rhs) { - *real = rhs.real(); - *imag = rhs.imag(); - return *this; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank - and layout within memory. A TensorRef combines a pointer and a Layout concept - -*/ -template < - /// Data type of element stored within tensor (concept: NumericType) - typename Element_, - /// Defines a mapping from logical coordinate to linear memory (concept: Layout) - typename Layout_ -> -class TensorRefPlanarComplex { - public: - /// Data type of individual access - using Element = Element_; - - /// Complex element type - using ComplexElement = complex; - - /// Mapping function from logical coordinate to linear memory - using Layout = Layout_; - - static_assert(sizeof_bits::value >= 8, - "Planar complex not suitable for subbyte elements at this time"); - - /// Reference type to an element - using Reference = PlanarComplexReference; - - /// Logical rank of tensor index space - static int const kRank = Layout::kRank; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Layout's stride vector - using Stride = typename Layout::Stride; - - /// TensorRef to constant data - using ConstTensorRef = TensorRefPlanarComplex< - typename platform::remove_const::type const, - Layout>; - - /// TensorRef to non-constant data - using NonConstTensorRef = TensorRefPlanarComplex< - typename platform::remove_const::type, - Layout>; - - /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a - /// scalar, but degenerate cases such as these are difficult to accommodate without - /// extensive C++ metaprogramming or support for zero-length arrays. - static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); - - private: - - /// Pointer - Element* ptr_; - - /// Layout object maps logical coordinates to linear offsets - Layout layout_; - - /// Offset to imaginary part - LongIndex imaginary_stride_; - - public: - - // - // Methods - // - - /// Constructs a TensorRef with a pointer and layout object. - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex( - Element *ptr = nullptr, ///< pointer to start of tensor - Layout const &layout = Layout(), ///< layout object containing stride and mapping function - LongIndex imaginary_stride = 0 - ): - ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { - - } - - /// Converting constructor from TensorRef to non-constant data. - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex( - NonConstTensorRef const &ref ///< TensorRef to non-const data - ): - ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } - - /// Returns a reference to constant-valued tensor. - CUTLASS_HOST_DEVICE - ConstTensorRef const_ref() const { - return ConstTensorRef(ptr_, layout_, imaginary_stride_); - } - - CUTLASS_HOST_DEVICE - NonConstTensorRef non_const_ref() const { - return NonConstTensorRef( - const_cast::type *>(ptr_), - layout_, - imaginary_stride_); - } - - /// Updates only the pointer - CUTLASS_HOST_DEVICE - void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { - ptr_ = ptr; - imaginary_stride_ = imaginary_stride; - } - - /// Updates the pointer and layout object - CUTLASS_HOST_DEVICE - void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { - ptr_ = ptr; - layout_ = layout; - imaginary_stride_ = imaginary_stride; - } - - /// Returns true if the TensorRef is non-null - CUTLASS_HOST_DEVICE - bool good() const { - return ptr_ != nullptr; - } - - /// Returns the pointer to referenced data - CUTLASS_HOST_DEVICE - Element * data() const { return ptr_; } - - /// Returns the pointer to referenced data - CUTLASS_HOST_DEVICE - Element * imaginary_data() const { return ptr_ + imaginary_stride_; } - - /// Returns a reference to the element at a given linear index - CUTLASS_HOST_DEVICE - Reference data(LongIndex idx) const { - return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); - } - - /// Returns the layout object - CUTLASS_HOST_DEVICE - Layout & layout() { - return layout_; - } - - /// Returns the layout object - CUTLASS_HOST_DEVICE - Layout layout() const { - return layout_; - } - - /// Gets the stride to an imaginary element - LongIndex imaginary_stride() const { - return imaginary_stride_; - } - - /// Gets the stride to an imaginary element - LongIndex &imaginary_stride() { - return imaginary_stride_; - } - - /// Returns the layout object's stride vector - CUTLASS_HOST_DEVICE - Stride stride() const { - return layout_.stride(); - } - - /// Returns the layout object's stride vector - CUTLASS_HOST_DEVICE - Stride & stride() { - return layout_.stride(); - } - - /// Returns the layout object's stride in a given physical dimension - CUTLASS_HOST_DEVICE - Index stride(int dim) const { - return layout_.stride().at(dim); - } - - /// Returns the layout object's stride in a given physical dimension - CUTLASS_HOST_DEVICE - Index & stride(int dim) { - return layout_.stride().at(dim); - } - - /// Computes the offset of an index from the origin of the tensor - CUTLASS_HOST_DEVICE - LongIndex offset(TensorCoord const& coord) const { - return layout_(coord); - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference at(TensorCoord const& coord) const { - return data(offset(coord)); - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference operator[](TensorCoord const& coord) const { - return data(offset(coord)); - } - - /// Adds an offset to each pointer - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { - ptr_ += offset_; - return *this; - } - - /// Adds an offset to each pointer - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { - add_pointer_offset(offset(coord)); - return *this; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex operator+(TensorCoord const& b) const { - TensorRefPlanarComplex result(*this); - result.add_coord_offset(b); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex & operator+=(TensorCoord const& b) { - add_coord_offset(b); - return *this; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex operator-(TensorCoord const& b) const { - TensorRefPlanarComplex result(*this); - result.add_pointer_offset(-offset(b)); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorRefPlanarComplex & operator-=(TensorCoord const& b) { - add_pointer_offset(-offset(b)); - return *this; - } - - /// TensorRef to real-valued tensor - CUTLASS_HOST_DEVICE - cutlass::TensorRef ref_real() const { - return cutlass::TensorRef(data(), layout()); - } - - /// TensorRef to real-valued tensor - CUTLASS_HOST_DEVICE - cutlass::TensorRef ref_imag() const { - return cutlass::TensorRef(imaginary_data(), layout()); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Constructs a TensorRef, deducing types from arguments. -template < - typename Element, - typename Layout -> -CUTLASS_HOST_DEVICE -TensorRefPlanarComplex make_TensorRefPlanarComplex( - Element *ptr, - Layout const &layout, - int64_t imaginary_stride) { - - return TensorRefPlanarComplex(ptr, layout, imaginary_stride); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h deleted file mode 100644 index d669443abd8b5b246a9d2aaf2ce4dd91f782f948..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h +++ /dev/null @@ -1,297 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a structure containing strides and a pointer to tensor data. - - TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, - it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from - data storage and is therefore lightweight and may be embedded in larger tensor objects or - memory structures. - - See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to - linear memory. -*/ - -#pragma once - -#if !defined(__CUDACC_RTC__) -#include -#endif - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Data type of element stored within tensor - typename Element_, - /// Maps a Coord in the logical tensor index space to the internal n-D array - typename Layout_ -> -class TensorView : public TensorRef { - public: - - /// Base tensor reference - using Base = cutlass::TensorRef; - - /// Mapping function from logical coordinate to internal n-D array - using Layout = Layout_; - - /// TensorRef pointing to constant memory - using ConstTensorRef = typename Base::ConstTensorRef; - - /// Underlying TensorRef type - using TensorRef = Base; - - /// Data type of individual access - using Element = Element_; - - /// Reference type to an element - using Reference = Element &; - - /// Logical rank of tensor index space - static int const kRank = Layout::kRank; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Coordinate in storage n-D array - using Stride = typename Layout::Stride; - - /// TensorView pointing to constant memory - using ConstTensorView = TensorView< - typename platform::remove_const::type const, - Layout>; - - /// TensorView pointing to non-constant memory - using NonConstTensorView = TensorView< - typename platform::remove_const::type, - Layout>; - - /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a - /// scalar, but degenerate cases such as these are difficult to accommodate without - /// extensive C++ metaprogramming or support for zero-length arrays. - static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); - - private: - - /// View extent - TensorCoord extent_; - - public: - - // - // Methods - // - - /// Constructs a TensorView object - CUTLASS_HOST_DEVICE - TensorView() { } - - /// Constructs a TensorView object - CUTLASS_HOST_DEVICE - TensorView( - Element *ptr, ///< pointer to start of tensor - Layout const &layout, ///< layout object containing stride and mapping function - TensorCoord const &extent ///< size of the view in logical coordinates - ): - Base(ptr, layout), extent_(extent) { - - } - - /// Constructs a TensorView object - CUTLASS_HOST_DEVICE - TensorView( - TensorRef const &ref, ///< pointer and layout object referencing a tensor - TensorCoord const &extent ///< logical size of tensor - ): - Base(ref), extent_(extent) { - - } - - /// Converting constructor from TensorRef to non-constant data. - CUTLASS_HOST_DEVICE - TensorView( - NonConstTensorView const &view ///< TensorView to non-const data - ): - Base(view), extent_(view.extent_) { } - - /// Updates the pointer and layout object - CUTLASS_HOST_DEVICE - void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) { - Base::reset(ptr, layout); - this->resize(extent); - } - - /// Updates the pointer - CUTLASS_HOST_DEVICE - void reset(Element* ptr) { - Base::reset(ptr); - } - - /// Changes the size of the view without affecting pointer or layout - CUTLASS_HOST_DEVICE - void resize(TensorCoord const &extent) { - this->extent_ = extent; - } - - /// Returns the extent of the view (the size along each logical dimension). - CUTLASS_HOST_DEVICE - TensorCoord const& extent() const { return extent_; } - - /// Returns the extent along a particular logical dimension. - CUTLASS_HOST_DEVICE - Index extent(int dim) const { return extent_.at(dim); } - - /// Returns the number of logical elements - CUTLASS_HOST_DEVICE - LongIndex size() const { - return extent_.product(); - } - - /// Determines whether a location is within a tensor - CUTLASS_HOST_DEVICE - bool contains(TensorCoord const& coord) const { - CUTLASS_PRAGMA_UNROLL - for (int dim = 0; dim < kRank; ++dim) { - if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { - return false; - } - } - return true; - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - TensorRef ref() const { - return TensorRef(this->data(), this->layout()); - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - ConstTensorRef const_ref() const { - return ConstTensorRef(this->data(), this->layout()); - } - - /// Returns a TensorView to const data - CUTLASS_HOST_DEVICE - ConstTensorView const_view() const { - return ConstTensorView(const_ref(), extent_); - } - - /// Returns a Tensor_view given location and size quantities - CUTLASS_HOST_DEVICE - TensorView subview( - TensorCoord extent, ///< extent of the resulting view - TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view - ) const { - - TensorView result(this->ref(), extent.clamp(extent_ - location)); - result.add_coord_offset(location); - return result; - } - - /// Returns the number of scalar elements needed to store tensor. - CUTLASS_HOST_DEVICE - size_t capacity() const { - return Base::layout().capacity(extent_); - } - - /// Returns a TensorView offset by a given amount - CUTLASS_HOST_DEVICE - TensorView operator+( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) const { - - TensorView result(*this); - result.add_pointer_offset(this->offset(b)); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorView& operator+=( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) { - - this->add_pointer_offset(this->offset(b)); - return *this; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorView operator-( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) const { - - TensorRef result(*this); - result.add_pointer_offset(-this->offset(b)); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorView& operator-=( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) { - - this->add_pointer_offset(-this->offset(b)); - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Constructs a TensorRef, deducing types from arguments. -template < - typename Element, - typename Layout -> -CUTLASS_HOST_DEVICE TensorView make_TensorView( - Element *ptr, - Layout const &layout, - typename Layout::TensorCoord const &extent) { - - return TensorView(ptr, layout, extent); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h deleted file mode 100644 index 6b8f7b47c49d75f0b000d134031ea169fcc6d2a6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h +++ /dev/null @@ -1,302 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a structure containing strides and a pointer to tensor data. - - TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, - it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from - data storage and is therefore lightweight and may be embedded in larger tensor objects or - memory structures. - - See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to - linear memory. -*/ - -#pragma once - -#if !defined(__CUDACC_RTC__) -#include -#endif - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref_planar_complex.h" -#include "cutlass/tensor_view.h" // cutlass::TensorView - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Data type of element stored within tensor - typename Element_, - /// Maps a Coord in the logical tensor index space to the internal n-D array - typename Layout_ -> -class TensorViewPlanarComplex : public TensorRefPlanarComplex { - public: - - /// Base tensor reference - using Base = cutlass::TensorRefPlanarComplex; - - /// Mapping function from logical coordinate to internal n-D array - using Layout = Layout_; - - /// TensorRef pointing to constant memory - using ConstTensorRef = typename Base::ConstTensorRef; - - /// Underlying TensorRef type - using TensorRef = Base; - - /// Data type of individual access - using Element = Element_; - - /// Reference type to an element - using Reference = Element &; - - /// Logical rank of tensor index space - static int const kRank = Layout::kRank; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Coordinate in storage n-D array - using Stride = typename Layout::Stride; - - /// TensorView pointing to constant memory - using ConstTensorView = TensorViewPlanarComplex< - typename platform::remove_const::type const, - Layout>; - - /// TensorView pointing to non-constant memory - using NonConstTensorView = TensorViewPlanarComplex< - typename platform::remove_const::type, - Layout>; - - /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a - /// scalar, but degenerate cases such as these are difficult to accommodate without - /// extensive C++ metaprogramming or support for zero-length arrays. - static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); - - private: - - /// View extent - TensorCoord extent_; - - public: - - // - // Methods - // - - /// Constructs a TensorView object - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { - - } - - /// Constructs a TensorView object - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex( - Element *ptr, ///< pointer to start of tensor - Layout const &layout, ///< layout object containing stride and mapping function - LongIndex imaginary_stride, ///< stride between real and imaginary part - TensorCoord const &extent ///< size of the view in logical coordinates - ): - Base(ptr, layout, imaginary_stride), extent_(extent) { - - } - - /// Constructs a TensorView object - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex( - TensorRef const &ref, ///< pointer and layout object referencing a tensor - TensorCoord const &extent ///< logical size of tensor - ): - Base(ref), extent_(extent) { - - } - - /// Converting constructor from TensorRef to non-constant data. - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex( - NonConstTensorView const &view ///< TensorView to non-const data - ): - Base(view), extent_(view.extent_) { } - - /// Updates the pointer and layout object - CUTLASS_HOST_DEVICE - void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { - Base::reset(ptr, layout, imaginary_stride); - this->resize(extent_); - } - - /// Changes the size of the view without affecting pointer or layout - CUTLASS_HOST_DEVICE - void resize(TensorCoord extent) { - this->extent_ = extent; - } - - /// Returns the extent of the view (the size along each logical dimension). - CUTLASS_HOST_DEVICE - TensorCoord const& extent() const { return extent_; } - - /// Returns the extent along a particular logical dimension. - CUTLASS_HOST_DEVICE - Index extent(int dim) const { return extent_.at(dim); } - - /// Determines whether a location is within a tensor - CUTLASS_HOST_DEVICE - bool contains(TensorCoord const& coord) const { - CUTLASS_PRAGMA_UNROLL - for (int dim = 0; dim < kRank; ++dim) { - if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { - return false; - } - } - return true; - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - Base ref() const { - return Base(this->data(), this->layout(), this->imaginary_stride()); - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - ConstTensorRef const_ref() const { - return ConstTensorRef(this->data(), this->layout()); - } - - /// Returns a TensorView to const data - CUTLASS_HOST_DEVICE - ConstTensorView const_view() const { - return ConstTensorView(const_ref(), extent_); - } - - /// Returns a Tensor_view given location and size quantities - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex subview( - TensorCoord extent, ///< extent of the resulting view - TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view - ) const { - - TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location)); - result.add_coord_offset(location); - return result; - } - - /// Returns the number of scalar elements needed to store tensor. - CUTLASS_HOST_DEVICE - size_t capacity() const { - return Base::layout().capacity(extent_); - } - - /// Returns a TensorView offset by a given amount - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex operator+( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) const { - - TensorViewPlanarComplex result(*this); - result.add_pointer_offset(this->offset(b)); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex& operator+=( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) { - - this->add_pointer_offset(this->offset(b)); - return *this; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex operator-( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) const { - - TensorRef result(*this); - result.add_pointer_offset(-this->offset(b)); - return result; - } - - /// Returns a TensorRef offset by a given amount - CUTLASS_HOST_DEVICE - TensorViewPlanarComplex& operator-=( - TensorCoord const& b ///< offset in the logical coordinate space of the tensor - ) { - - this->add_pointer_offset(-this->offset(b)); - return *this; - } - - /// TensorRef to real-valued tensor - CUTLASS_HOST_DEVICE - cutlass::TensorView view_real() const { - return cutlass::TensorView(this->data(), this->layout(), extent_); - } - - /// TensorRef to real-valued tensor - CUTLASS_HOST_DEVICE - cutlass::TensorView view_imag() const { - return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Constructs a TensorRef, deducing types from arguments. -template < - typename Element, - typename Layout -> -CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( - Element *ptr, - Layout const &layout, - typename Layout::LongIndex imaginary_stride, - typename Layout::TensorCoord const &extent) { - - return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h deleted file mode 100644 index 7bc13e177f1d027fbba789367ac3f2ee5b748877..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h +++ /dev/null @@ -1,479 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Defines a proxy class for storing Tensor Float 32 data type. -*/ -#pragma once - -#if defined(__CUDACC_RTC__) -#include "cutlass/floating_point_nvrtc.h" -#else -#include -#include -#include -#include // std::memcpy -#endif - -#include "cutlass/cutlass.h" - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tensor Float 32 data type -struct alignas(4) tfloat32_t { - - // - // Data members - // - - /// Storage type - uint32_t storage; - - // - // Methods - // - private: - CUTLASS_HOST_DEVICE - static uint32_t float_to_storage(float s) { - #if defined(__CUDA_ARCH__) - uint32_t result = reinterpret_cast(s); - #else - uint32_t result; - std::memcpy(&result, &s, sizeof(float)); - #endif - return result; - } - - public: - /// Constructs from an unsigned int - CUTLASS_HOST_DEVICE - static tfloat32_t bitcast(uint32_t x) { - tfloat32_t h; - h.storage = x; - return h; - } - - /// Emulated rounding is fast in device code - CUTLASS_HOST_DEVICE - static tfloat32_t round_half_ulp_truncate(float const &s) { - uint32_t x = float_to_storage(s); - - #if defined(__CUDA_ARCH__) - if (::isfinite(s)) { - x += 0x1000u; - } - #else - if (std::isfinite(s)) { - x += 0x1000u; - } - #endif - - return tfloat32_t::bitcast(x); - } - - tfloat32_t() = default; - - /// Floating-point conversion - round toward nearest even - CUTLASS_HOST_DEVICE - explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { } - - // Conversion from double (this rounds twice) - CUTLASS_HOST_DEVICE - explicit tfloat32_t(double x): tfloat32_t(float(x)) { } - - /// Integer conversion - round toward zero - CUTLASS_HOST_DEVICE - explicit tfloat32_t(int x) { - float flt = static_cast(x); - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(flt); - #else - std::memcpy(&storage, &flt, sizeof(storage)); - #endif - } - - // Conversion to float - CUTLASS_HOST_DEVICE - operator float() const { - - // Conversions to IEEE single-precision requires clearing dont-care bits - // of the mantissa. - unsigned bits = (storage & ~0x1fffu); - - #if defined(__CUDA_ARCH__) - return reinterpret_cast(bits); - #else - float flt; - std::memcpy(&flt, &bits, sizeof(flt)); - return flt; - #endif - } - - /// Converts to double - CUTLASS_HOST_DEVICE - explicit operator double() const { - return double(float(*this)); - } - - /// Converts to int - CUTLASS_HOST_DEVICE - explicit operator int() const { - return int(float(*this)); - } - - /// Casts to bool - CUTLASS_HOST_DEVICE - explicit operator bool() const { - return (float(*this) != 0.0f); - } - - /// Obtains raw bits - CUTLASS_HOST_DEVICE - uint32_t raw() const { - return storage; - } - - /// Returns the sign bit - CUTLASS_HOST_DEVICE - bool signbit() const { - return ((raw() & 0x80000000) != 0); - } - - /// Returns the biased exponent - CUTLASS_HOST_DEVICE - int exponent_biased() const { - return int((raw() >> 23) & 0x0ff); - } - - /// Returns the unbiased exponent - CUTLASS_HOST_DEVICE - int exponent() const { - return exponent_biased() - 127; - } - - /// Returns the mantissa - CUTLASS_HOST_DEVICE - int mantissa() const { - return int(raw() & 0x7fffff); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -bool signbit(cutlass::tfloat32_t const& h) { - return h.signbit(); -} - -CUTLASS_HOST_DEVICE -cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) { - return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff); -} - -CUTLASS_HOST_DEVICE -bool isnan(cutlass::tfloat32_t const& h) { - return (h.exponent_biased() == 0x0ff) && h.mantissa(); -} - -CUTLASS_HOST_DEVICE -bool isfinite(cutlass::tfloat32_t const& h) { - return (h.exponent_biased() != 0x0ff); -} - -CUTLASS_HOST_DEVICE -cutlass::tfloat32_t nan_tf32(const char*) { - // NVIDIA canonical NaN - return cutlass::tfloat32_t::bitcast(0x7fffffff); -} - -CUTLASS_HOST_DEVICE -bool isinf(cutlass::tfloat32_t const& h) { - return (h.exponent_biased() == 0x0ff) && !h.mantissa(); -} - -CUTLASS_HOST_DEVICE -bool isnormal(cutlass::tfloat32_t const& h) { - return h.exponent_biased() && h.exponent_biased() != 0x0ff; -} - -CUTLASS_HOST_DEVICE -int fpclassify(cutlass::tfloat32_t const& h) { - int exp = h.exponent_biased(); - int mantissa = h.mantissa(); - if (exp == 0x0ff) { - if (mantissa) { - return FP_NAN; - } - else { - return FP_INFINITE; - } - } - else if (!exp) { - if (mantissa) { - return FP_SUBNORMAL; - } - else { - return FP_ZERO; - } - } - return FP_NORMAL; -} - -CUTLASS_HOST_DEVICE -cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { -#if defined(__CUDACC_RTC__) - return cutlass::tfloat32_t(sqrtf(float(h))); -#else - return cutlass::tfloat32_t(std::sqrt(float(h))); -#endif -} - -CUTLASS_HOST_DEVICE -tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { - - uint32_t a_mag = (a.raw() & 0x7fffffff); - uint32_t b_sign = (b.raw() & 0x80000000); - uint32_t result = (a_mag | b_sign); - - return tfloat32_t::bitcast(result); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Standard Library operations and definitions -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace std { - -#if !defined(__CUDACC_RTC__) -/// Numeric limits -template <> -struct numeric_limits { - static bool const is_specialized = true; - static bool const is_signed = true; - static bool const is_integer = false; - static bool const is_exact = false; - static bool const has_infinity = true; - static bool const has_quiet_NaN = true; - static bool const has_signaling_NaN = false; - static std::float_denorm_style const has_denorm = std::denorm_present; - static bool const has_denorm_loss = true; - static std::float_round_style const round_style = std::round_to_nearest; - static bool const is_iec559 = false; - static bool const is_bounded = true; - static bool const is_modulo = false; - static int const digits = 19; - - /// Least positive value - static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } - - /// Minimum finite value - static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } - - /// Maximum finite value - static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } - - /// Returns smallest finite value - static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } - - /// Returns smallest finite value - static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } - - /// Returns smallest finite value - static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } - - /// Returns smallest finite value - static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } - - /// Returns smallest finite value - static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } - - /// Returns smallest finite value - static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } -}; -#endif - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace std - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Arithmetic operators -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_HOST_DEVICE -bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return float(lhs) == float(rhs); -} - -CUTLASS_HOST_DEVICE -bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return float(lhs) != float(rhs); -} - -CUTLASS_HOST_DEVICE -bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return float(lhs) < float(rhs); -} - -CUTLASS_HOST_DEVICE -bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return float(lhs) <= float(rhs); -} - -CUTLASS_HOST_DEVICE -bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return float(lhs) > float(rhs); -} - -CUTLASS_HOST_DEVICE -bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return float(lhs) >= float(rhs); -} - -CUTLASS_HOST_DEVICE -tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return tfloat32_t(float(lhs) + float(rhs)); -} - - -CUTLASS_HOST_DEVICE -tfloat32_t operator-(tfloat32_t const& lhs) { - return tfloat32_t::bitcast(0x80000000 ^ lhs.raw()); -} - -CUTLASS_HOST_DEVICE -tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return tfloat32_t(float(lhs) - float(rhs)); -} - -CUTLASS_HOST_DEVICE -tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return tfloat32_t(float(lhs) * float(rhs)); -} - -CUTLASS_HOST_DEVICE -tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) { - return tfloat32_t(float(lhs) / float(rhs)); -} - -CUTLASS_HOST_DEVICE -tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) { - lhs = tfloat32_t(float(lhs) + float(rhs)); - return lhs; -} - -CUTLASS_HOST_DEVICE -tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) { - lhs = tfloat32_t(float(lhs) - float(rhs)); - return lhs; -} - -CUTLASS_HOST_DEVICE -tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) { - lhs = tfloat32_t(float(lhs) * float(rhs)); - return lhs; -} - -CUTLASS_HOST_DEVICE -tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) { - lhs = tfloat32_t(float(lhs) / float(rhs)); - return lhs; -} - -CUTLASS_HOST_DEVICE -tfloat32_t& operator++(tfloat32_t & lhs) { - float tmp(lhs); - ++tmp; - lhs = tfloat32_t(tmp); - return lhs; -} - -CUTLASS_HOST_DEVICE -tfloat32_t& operator--(tfloat32_t & lhs) { - float tmp(lhs); - --tmp; - lhs = tfloat32_t(tmp); - return lhs; -} - -CUTLASS_HOST_DEVICE -tfloat32_t operator++(tfloat32_t & lhs, int) { - tfloat32_t ret(lhs); - float tmp(lhs); - tmp++; - lhs = tfloat32_t(tmp); - return ret; -} - -CUTLASS_HOST_DEVICE -tfloat32_t operator--(tfloat32_t & lhs, int) { - tfloat32_t ret(lhs); - float tmp(lhs); - tmp--; - lhs = tfloat32_t(tmp); - return ret; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// User-defined literals -// - -CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(long double x) { - return cutlass::tfloat32_t(float(x)); -} - -CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { - return cutlass::tfloat32_t(int(x)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h deleted file mode 100644 index c338306132b9d9b2e42ff26759f7d1b3a7bc1ae3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h +++ /dev/null @@ -1,198 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines a matrix object intended for storing data in registers and operations within - a CUDA thread. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/matrix_coord.h" - -namespace cutlass { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Per-thread matrix object storing a packed matrix -template < - typename Element, - int Rows, - int Columns, - typename Layout = layout::RowMajor -> -class Matrix : public Array { -public: - - // Verify layout refers to a rank=2 matrix. - static_assert( - Layout::kRank == 2, - "Layout type must refer to a rank=2 matrix"); - - /// Base type - using Base = Array; - - /// Element type - using Element = Element_; - - /// Number of rows - static int const kRows = Rows; - - /// Number of columns - static int const kColumns = Columns; - - /// Layout within the array - using Layout = Layout_; - - /// Reference type to an element - using Reference = Element &; - - /// Logical rank of tensor index space - static int const kRank = 2; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Stride type - using Stride = typename Layout::Stride; - - /// TensorRef to matrix object - using TensorRef = TensorRef; - - /// TensorRef to constant matrix object - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - /// TensorRef to matrix object - using TensorView = TensorView; - - /// TensorRef to constant matrix object - using ConstTensorView = typename TensorView::ConstTensorView; - - /// Diagonal vector - using Diagonal = Vector; - -private: - - -public: - - // - // Methods - // - - /// Returns the size of the object - CUTLASS_HOST_DEVICE - static MatrixCoord extent() { - return make_Coord(kRows, kColumns); - } - - /// Returns the layout object - CUTLASS_HOST_DEVICE - static Layout layout() { - return Layout::packed(extent()); - } - - /// Ctor - CUTLASS_HOST_DEVICE - Matrix() { } - - /// Ctor - CUTLASS_HOST_DEVICE - Matrix(Diagonal const &diag) { - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - TensorRef ref() { - return TensorRef(this->data(), layout()); - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - ConstTensorRef const_ref() const { - return ConstTensorRef(this->data(), layout()); - } - - /// Returns a TensorRef pointing to the first element of the tensor. - CUTLASS_HOST_DEVICE - TensorView view() { - return TensorView(ref(), extent()); - } - - /// Returns a TensorView to const data - CUTLASS_HOST_DEVICE - ConstTensorView const_view() const { - return ConstTensorView(const_ref(), extent()); - } - - /// Returns a reference to the element at a given Coord - CUTLASS_HOST_DEVICE - Reference at(MatrixCoord const& coord) const { - typename Base::size_type offset_(layout().offset(coord)); - return Base::at(offset_); - } - - /// Returns the number of scalar elements needed to store tensor. - CUTLASS_HOST_DEVICE - LongIndex capacity() const { - return LongIndex(Base::size()); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Column vector defined as a matrix with exactly one column -template < - typename Element, - int Rows, - typename Layout = layout::ColumnMajor -> -using ColumnVector = Matrix; - -/// Row vector defined as a matrix with exactly one row -template < - typename Element, - int Columns, - typename Layout = layout::RowMajor -> -using RowVector = Matrix; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/trace.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/trace.h deleted file mode 100644 index 803c72eca35a4cc3ee0712981942016f987f5b44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/trace.h +++ /dev/null @@ -1,59 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Helpers for optionally tracing through code when debugging. - - This file is to be included after all other headers. -*/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Tracing options -#ifndef CUTLASS_DEBUG_TRACE_LEVEL -#define CUTLASS_DEBUG_TRACE_LEVEL 0 -#endif - -#if CUTLASS_DEBUG_TRACE_LEVEL -#include -#include "cutlass/core_io.h" -#if defined(__CUDA_ARCH__) -#define CUTLASS_TRACE_HOST(x) -#else -#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } -#endif -#else -#define CUTLASS_TRACE_HOST(x) -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp deleted file mode 100644 index 41bc4786c7a8d148340a23bf1ce1db66f04f10b4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +++ /dev/null @@ -1,754 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing how threads are mapped to a given tile. -*/ - -#pragma once - -#include "cute/arch/mma_sm90_gmma.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { -using namespace cute; - -template -constexpr auto -gmma_smem_transpose_or_passthrough() { - if constexpr (Transpose) { - if constexpr (cute::is_same_v, SmemLayoutAtom>) { - return GMMA::Layout_K_SW128_Atom{}; - } - else if constexpr (cute::is_same_v, SmemLayoutAtom>) { - return GMMA::Layout_K_SW64_Atom{}; - } - else if constexpr (cute::is_same_v, SmemLayoutAtom>) { - return GMMA::Layout_K_SW32_Atom{}; - } - else if constexpr (cute::is_same_v, SmemLayoutAtom>) { - return GMMA::Layout_K_INTER_Atom{}; - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported Layout_SW_Atom for B SMEM transposition"); - } - } - else { - return SmemLayoutAtom{}; - } -} - -template -constexpr auto -use_universal_transposition() { - if constexpr (sizeof(ElementType) == 1) { - return !cute::is_same_v, SmemCopyAtom>; - } - else if constexpr (sizeof(ElementType) == 4){ - // Only universal transposition can handle SW64 and Non swizzle SMEM layout - if constexpr (cute::is_same_v, SmemCopyAtom> || - cute::is_same_v, SmemCopyAtom>) { - return true; - } - else { - return false; - } - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported ElementType for B SMEM transposition"); - } -} - -template< - class TiledMma_, - class SmemLayoutB_, - class SmemLayoutAtomB_, - class ElementB_> -class NoTranspositionOperandB { -public: - using TiledMma = TiledMma_; - using SmemLayoutB = SmemLayoutB_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using ElementB = ElementB_; - - constexpr CUTLASS_HOST_DEVICE - NoTranspositionOperandB( - int, - int, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) { } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void operator()( - TensorSmemB const&, - TensorTransposedSmemB const&, - int, int) { } - - CUTLASS_DEVICE void synchronize(int) { } - - CUTLASS_DEVICE void synchronize() { } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void transpose( - TensorSmemB const&, - TensorTransposedSmemB const&, - int) { } -}; - -template< - class TiledMma_, - class SmemLayoutB_, - class SmemLayoutAtomB_, - class ElementB_> -class UniversalTranspositionOperandB { -public: - using TiledMma = TiledMma_; - using SmemLayoutB = SmemLayoutB_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using ElementB = ElementB_; - - constexpr CUTLASS_HOST_DEVICE - UniversalTranspositionOperandB( - int warp_idx_, - int warp_group_thread_idx_, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) - : warp_idx(warp_idx_) - , warp_group_thread_idx(warp_group_thread_idx_) { } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void operator()( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - int read_stage, int current_step) { - if (current_step > 0) { - return; - } - - constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static_assert(NumMathWarpGroup == 1 || - (!detail::use_universal_transposition() && NumMathWarpGroup == 2), - "Wrong math warp group number for TransposeB"); - constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. - - constexpr int BytesPerSmemSwizzleUnit = 16; - constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// Universal transposition, need warp_group sync between load and store. - /// The number of reg used depends on the input elementB. - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /* - In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location. - In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements: - K - ------------ - | W0 W1 W2 W3 --- - | W0 W1 W2 W3 | - | W0 W1 W2 W3 | --> Copy Step 0 - | W0 W1 W2 W3 --- - .... - | W0 W1 W2 W3 --- - | W0 W1 W2 W3 | - | W0 W1 W2 W3 | --> Copy Step n - | W0 W1 W2 W3 --- - */ - static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout."); - constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int{}, Int{})); - - // Get copy tile and partition to each thread - auto sB_tiled_copy = make_tiled_copy( - Copy_Atom{}, - WarpgroupThreadLayout, // thr_layout - Layout<_1>{} // val_layout - ); - static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy."); - - auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx); - Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) - Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) - - // Divide partitioned tile to limit register usage - constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize; - constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB)); - static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM."); - - Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape); - Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape); - auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{})); - - CUTLASS_PRAGMA_NO_UNROLL - for (int step = 0; step < CopySteps; ++step) { - copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment); - - // Make sure all elements are read before being overwritten - __syncthreads(); - - copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step)); - } - } - - CUTLASS_DEVICE void synchronize(int step) { - if (step == 0) { - // SMEM fence to make sure B is transposed before math - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); - } - } - - CUTLASS_DEVICE void synchronize() { - // SMEM fence to make sure B is transposed before math - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); - } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void transpose( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - int read_stage) { - - this->operator()(sB, gmma_sB, read_stage, 0); - synchronize(); - - } - -private: - const int warp_idx; - const int warp_group_thread_idx; -}; - -template< - class TiledMma_, - class SmemLayoutB_, - class SmemLayoutAtomB_, - class ElementB_> -class AsyncTranspositionOperandB { -public: - - using TiledMma = TiledMma_; - using SmemLayoutB = SmemLayoutB_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using ElementB = ElementB_; - - static constexpr int Steps = 2; - static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; - static_assert(NumMathWarpGroup <= 2, - "Wrong math warp group number for TransposeB"); - static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. - static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; - - static constexpr int BytesPerSmemSwizzleUnit = 16; - static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); - static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; - static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); - - static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); - static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. - static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; - static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; - static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. - static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, - static constexpr int NumBitsPerStep = 3; - static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) - static constexpr int NumBitsPerWarp = 12; - // Number of warp_group_tiles - static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, - "Copy size must evenly divide SMEM tile."); - static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; - - static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK, - "Need to be able to transpose first k-block in the first step"); - - constexpr CUTLASS_HOST_DEVICE - AsyncTranspositionOperandB( - int warp_idx_, - int warp_group_thread_idx_, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) - : warp_idx(warp_idx_) - , warp_group_thread_idx(warp_group_thread_idx_) - , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) - , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ - % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) - , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ - % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void operator()( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - int read_stage, int current_step) - { - if (current_step >= StepsPerWarpGroup) { - return; - } - - static constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. - /// In each step, one warp would hold two warp_tiles. - /// Step 0: Step 1: - /// W0 W1 W2 W3 -- -- -- -- - /// W1 W0 -- -- -- -- W3 W2 - /// W2 -- -- -- -- W3 W0 W1 - /// W3 -- -- -- -- W2 W1 W0 - /// - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// - /// Fully static coord LUT to avoid extra register use. - /// [warp_id][step][warp_tile][n / k] - /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 - /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 - /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 - /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 - /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 - /// - /// Encoding the coord of warp tile0 into two int64_t values. - /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. - /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. - /// The 2-step transposition and the 8-step transposition share the same encoding. - /// - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - - // Divide entire SMEM to multiple warp_tiles - constexpr auto WarpTileShape = make_shape(Int(), Int()); - Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); - Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); - - // Get copy tile - auto sB_tiled_copy = make_tiled_copy( - Copy_Atom{}, - WarpThreadLayout, // thr_layout - Layout<_1>{} // val_layout - ); - - static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); - auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx - - // Construct fragments for transposition - Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); - decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { - make_fragment_like(tmp_tCsB), - make_fragment_like(tmp_tCsB) - }; - - [[maybe_unused]] int step = current_step * NumMathWarpGroup; - if constexpr (NumMathWarpGroup == 2) { - // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. - step += warp_idx / (NumWarpsPerWarpGroup * 2); - } - - int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step); - int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step); - - if constexpr (NumMathWarpGroup == 2) { - tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); - tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); - } - - // decoding the warp tile coord. - int warp_tile0_n, warp_tile0_k; - if constexpr (StepsPerWarpGroup <= NumStepsEncoded) { - warp_tile0_n = tmp_warp_tile_n_coord_LUT & MaskPerStep; - warp_tile0_k = tmp_warp_tile_k_coord_LUT & MaskPerStep; - } else { - warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; - warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; - } - - int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; - int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; - - CUTLASS_PRAGMA_UNROLL - for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { - - static_assert(TilesPerWarp == 2); - - // [warp_tile][n/k] - const int warp_tile_coord[TilesPerWarp][2] = { - // n k - {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 - {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 - }; - - CUTLASS_PRAGMA_UNROLL - for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { - Tensor tCsB = sB_thr_copy.partition_S( - flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) - ); // (CPY, CPY_N, CPY_K) - - copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); - } - - // Make sure elements in two 8x8 warp tiles are all consumed - __syncwarp(); - - CUTLASS_PRAGMA_UNROLL - for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { - Tensor tCsB_transposed = sB_thr_copy.partition_D( - flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) - ); // (CPY, CPY_N, CPY_K) - copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); - } - - } // loop warp_group_tile - } - - CUTLASS_DEVICE void synchronize(int step) { - if (step < StepsPerWarpGroup) { - // SMEM fence to make sure B is transposed before math - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); - } - } - - CUTLASS_DEVICE void synchronize() { - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); - } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void transpose( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - int read_stage) { - - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < StepsPerWarpGroup; ++i) { - this->operator()(sB, gmma_sB, read_stage, i); - } - synchronize(); - - } -private: - const int warp_idx; - const int warp_group_thread_idx; - const int warp_idx_in_warp_group; - const int current_warp_tile_n_coord_LUT; - const int current_warp_tile_k_coord_LUT; -}; - -template< - class TiledMma_, - class SmemLayoutB_, - class SmemLayoutAtomB_, - class ElementB_> -class AsyncTranspositionOperandB_1BElementB { -public: - - static_assert(sizeof(ElementB_) == 1); - - using TiledMma = TiledMma_; - using SmemLayoutB = SmemLayoutB_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using ElementB = ElementB_; - - static constexpr int Steps = 8; - static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; - static_assert(NumMathWarpGroup <= 2, - "Wrong math warp group number for TransposeB"); - static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. - static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; - - static constexpr int BytesPerSmemSwizzleUnit = 16; - static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); - static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; - static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); - - static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); - static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. - static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; - static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; - static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. - static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, - static constexpr int NumBitsPerStep = 3; - static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) - static constexpr int NumBitsPerWarp = 12; - // Number of warp_group_tiles - static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, - "Copy size must evenly divide SMEM tile."); - static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; - - constexpr CUTLASS_HOST_DEVICE - AsyncTranspositionOperandB_1BElementB( - int warp_idx_, - int warp_group_thread_idx_, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) - : warp_idx(warp_idx_) - , warp_group_thread_idx(warp_group_thread_idx_) - , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) - , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ - % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) - , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ - % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void operator()( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - int read_stage, int current_step) - { - if (current_step > 0) { - return; - } - - constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. - /// Divide a warp_group_tile into 8x8 warp_tiles to further reduce the reg usage. - /// Step 0: Step 1: Step 2: Step 3: - /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - /// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - /// W2 -- -- -- -- -- -- -- -- W3 W0 W1 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - /// W3 -- -- -- -- -- -- -- -- W2 W1 W0 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- - /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W1 W0 -- -- -- -- -- -- -- -- W3 W2 - /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W3 W0 W1 - /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W2 W1 W0 - /// - /// Step 4: Step 5: Step 6: Step 7: - /// -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - /// -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- - /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 - /// W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- - /// W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- - /// W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- - /// W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- - /// - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// - /// Fully static coord LUT to avoid extra register use. - /// [warp_id][step][warp_tile][n / k] - /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 - /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 - /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 - /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 - /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 - /// - /// Encoding the coord of warp tile0 into two int64_t values. - /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. - /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. - /// The 2-step transposition and the 8-step transposition share the same encoding. - /// - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - - // Divide entire SMEM to multiple warp_tiles - constexpr auto WarpTileShape = make_shape(Int(), Int()); - Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); - Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); - - // Get copy tile - auto sB_tiled_copy = make_tiled_copy( - Copy_Atom{}, - WarpThreadLayout, // thr_layout - Layout<_1>{} // val_layout - ); - static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); - auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx - - // Construct fragments for transposition - Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); - decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { - make_fragment_like(tmp_tCsB), - make_fragment_like(tmp_tCsB) - }; - - CUTLASS_PRAGMA_NO_UNROLL - for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { - int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT; - int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT; - constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; - - if constexpr (NumMathWarpGroup == 2) { - tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); - tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); - } - - CUTLASS_PRAGMA_NO_UNROLL - for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) { - // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. - int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2); - // decoding the warp tile coord. - int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; - int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; - int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; - int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; - - tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep; - tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep; - - static_assert(TilesPerWarp == 2); - - // [warp_tile][n/k] - const int warp_tile_coord[TilesPerWarp][2] = { - // n k - {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 - {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 - }; - - CUTLASS_PRAGMA_UNROLL - for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { - Tensor tCsB = sB_thr_copy.partition_S( - flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) - ); // (CPY, CPY_N, CPY_K) - - copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); - } - - // Make sure elements in two 8x8 warp tiles are all consumed - __syncwarp(); - - CUTLASS_PRAGMA_UNROLL - for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { - Tensor tCsB_transposed = sB_thr_copy.partition_D( - flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) - ); // (CPY, CPY_N, CPY_K) - copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); - } - } // lock step - } // loop warp_group_tile - } - - CUTLASS_DEVICE void synchronize(int step) { - if (step == 0) { - // SMEM fence to make sure B is transposed before math - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); - } - } - - CUTLASS_DEVICE void synchronize() { - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); - } - - template < - class TensorSmemB, - class TensorTransposedSmemB> - CUTLASS_DEVICE void transpose( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - int read_stage) { - this->operator()(sB, gmma_sB, read_stage, 0); - synchronize(); - } - -private: - const int warp_idx; - const int warp_group_thread_idx; - const int warp_idx_in_warp_group; - const int current_warp_tile_n_coord_LUT; - const int current_warp_tile_k_coord_LUT; -}; - - -template< - class TiledMma, - class SmemLayoutB, - class SmemLayoutAtomB, - class ElementB, - bool TransposeB -> -constexpr CUTLASS_HOST_DEVICE -auto -make_transpose_operand_b( - int warp_idx, - int warp_group_thread_idx, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB, - cute::bool_constant) -{ - if constexpr (!TransposeB) { - return NoTranspositionOperandB( - warp_idx, warp_group_thread_idx, TiledMma{}, - SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); - } - else if constexpr (use_universal_transposition()) { - return UniversalTranspositionOperandB( - warp_idx, warp_group_thread_idx, TiledMma{}, - SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); - } - else if constexpr (sizeof(ElementB) == 1) { - return AsyncTranspositionOperandB_1BElementB( - warp_idx, warp_group_thread_idx, TiledMma{}, - SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); - } - else { - return AsyncTranspositionOperandB( - warp_idx, warp_group_thread_idx, TiledMma{}, - SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); - } -} - -}; // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace transform -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp deleted file mode 100644 index 265d2fe4367180b0c5c76f22df7d00f01dfb170e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp +++ /dev/null @@ -1,303 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Transform Kernel Universal adapter -*/ - -#pragma once - -// common -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/detail/layout.hpp" -#include "cutlass/detail/mma.hpp" -#include "cutlass/cuda_host_adapter.hpp" - -#include "cutlass/kernel_launch.h" -#if !defined(__CUDACC_RTC__) -#include "cutlass/cluster_launch.hpp" -#include "cutlass/trace.h" -#endif // !defined(__CUDACC_RTC__) - - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::transform::device { - -//////////////////////////////////////////////////////////////////////////////// - -template -class TransformUniversalAdapter -{ -public: - using TransformKernel = GetUnderlyingKernel_t; - using Arguments = typename TransformKernel::Arguments; - using Params = typename TransformKernel::Params; - static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; - - -private: - - /// Kernel API parameters object - Params params_; - -public: - - /// Access the Params structure - Params const& params() const { - return params_; - } - - /// Determines whether the GEMM can execute the given problem. - static Status - can_implement(Arguments const& args) { - return TransformKernel::can_implement(args); - } - - /// Gets the workspace size - static size_t - get_workspace_size(Arguments const& args) { - size_t workspace_bytes = 0; - workspace_bytes += TransformKernel::get_workspace_size(args); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - return workspace_bytes; - } - - /// Computes the grid shape - static dim3 - get_grid_shape(Arguments const& args, void* workspace = nullptr) { - auto tmp_params = TransformKernel::to_underlying_arguments(args, workspace); - return TransformKernel::get_grid_shape(tmp_params); - } - - /// Computes the grid shape - static dim3 - get_grid_shape(Params const& params) { - return TransformKernel::get_grid_shape(params); - } - - - /// Initializes GEMM state from arguments. - Status - initialize( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - - CUTLASS_TRACE_HOST("TransformUniversalAdapter::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null") - << ", EnableCudaHostAdapter: " << (kEnableCudaHostAdapter ? "True" : "false")); - - // Initialize the workspace - Status status = TransformKernel::initialize_workspace(args, workspace, stream, cuda_adapter); - if (status != Status::kSuccess) { - return status; - } - // Initialize the Params structure - params_ = TransformKernel::to_underlying_arguments(args, workspace); - // Don't set the function attributes - require the CudaHostAdapter to set it. - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - return Status::kSuccess; - } - else { - // - // Account for dynamic smem capacity if needed - // - int smem_size = TransformKernel::SharedStorageSize; - - CUTLASS_ASSERT(cuda_adapter == nullptr); - - if (smem_size >= (48 << 10)) { - CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - cudaError_t result = cudaFuncSetAttribute( - device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - } - return Status::kSuccess; - } - - static Status - run(Params& params, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr, - int32_t kernel_index = 0, - bool launch_with_pdl = false) { - CUTLASS_TRACE_HOST("TransformUniversalAdapter::run()"); - dim3 const block = TransformKernel::get_block_shape(); - dim3 const grid = get_grid_shape(params); - - // configure smem size and carveout - int smem_size = TransformKernel::SharedStorageSize; - - Status launch_result{ Status::kSuccess }; - // Use extended launch API only for mainloops that use it - if constexpr (TransformKernel::ArchTag::kMinComputeCapability >= 90) { - // Currently only support 1x1x1 for transform kernel. - dim3 const cluster = {1,1,1}; - void* kernel_params[] = {¶ms}; - - if constexpr (kEnableCudaHostAdapter) { - // - // Use the cuda host adapter - // - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - - if (launch_with_pdl) { - CUTLASS_TRACE_HOST( - "TransformUniversalAdapter::run() does not support launching with PDL and a custom cuda adapter."); - return Status::kErrorInternal; - } - launch_result = cuda_adapter->launch(grid, - cluster, - block, - smem_size, - stream, - kernel_params, - kernel_index); - CUTLASS_TRACE_HOST("Kernel Launch Result" << cutlassGetStatusString(launch_result)); - } - else { - return Status::kErrorInternal; - } - } - else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - void const* kernel = (void const*) device_kernel; - if constexpr (TransformKernel::ArchTag::kMinComputeCapability == 90) { - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); - } - } - } - else { - launch_result = Status::kSuccess; - cutlass::arch::synclog_setup(); - - if constexpr (kEnableCudaHostAdapter) { - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - void* kernel_params[] = {¶ms}; - - launch_result = cuda_adapter->launch( - grid, block, smem_size, stream, kernel_params, 0 - ); - - } - else { - return Status::kErrorInternal; - } - } - else { - CUTLASS_ASSERT(cuda_adapter == nullptr); - cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); - } - } - - cudaError_t result = cudaGetLastError(); - if (cudaSuccess == result && Status::kSuccess == launch_result) { - return Status::kSuccess; - } - else if (cudaSuccess != result) { - CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result)); - } - else if (Status::kSuccess != launch_result) { - CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cutlassGetStatusString(launch_result)); - } - return Status::kErrorInternal; - } - - // - // Non-static launch overloads that first create and set the internal params struct of this kernel handle. - // - - /// Launches the kernel after first constructing Params internal state from supplied arguments. - Status - run( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr, - int32_t kernel_index = 0, - bool launch_with_pdl = false - ) { - Status status = initialize(args, workspace, stream, cuda_adapter); - - if (Status::kSuccess == status) { - status = run(params_, stream, cuda_adapter, kernel_index, launch_with_pdl); - } - return status; - } - - /// Launches the kernel after first constructing Params internal state from supplied arguments. - Status - operator()( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr, - bool launch_with_pdl = false) { - return run(args, workspace, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); - } - - /// Overload that allows a user to re-launch the same kernel without updating internal params struct. - Status - run( - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr, - bool launch_with_pdl = false) { - return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); - } - - /// Overload that allows a user to re-launch the same kernel without updating internal params struct. - Status - operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { - return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::transform::device - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp deleted file mode 100644 index 9c9d7589a309ebe6276bb564ac76a9e036bdd50a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp +++ /dev/null @@ -1,223 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/* \file - \brief Convolution filter format transformation kernel. -*/ - -#pragma once - -#include -#include - -#include "cutlass/coord.h" -#include "cutlass/arch/arch.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/cuda_host_adapter.hpp" - -#include "cute/int_tuple.hpp" -#include "cute/tensor.hpp" -#include "cute/config.hpp" - -namespace cutlass::transform::kernel { - -using namespace cute; - -enum class FilterFormat { - CKTRS, - CTRSK, - KTRSC -}; - -template < - FilterFormat SrcFormat, - FilterFormat DstFormat, - int NumDimensions, - class Element_, - int AlignmentBytes = 16 -> -struct ConvFilterFormatTransformer { - - using Element = Element_; - static_assert(SrcFormat == FilterFormat::CKTRS, "Currently only source format of CKTRS is supported"); - static_assert(DstFormat == FilterFormat::CTRSK || DstFormat == FilterFormat::KTRSC, "Currently only destination format of CTRSK/KTRSC is supported"); - static_assert(AlignmentBytes > 0 && AlignmentBytes % static_cast(sizeof(Element)) == 0, "Invalid alignment setting"); - - // In ktrsc order. - using FilterExtent = array; - - // Default cta tile shape: 32x32 - static constexpr auto CTATileShape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); - // Default thread layout: (4, 32) - static constexpr auto ThreadLayout = make_layout(make_shape(Int<4>{}, Int<32>{})); - - static constexpr uint32_t MaxThreadsPerBlock = 128; - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - using ArchTag = arch::Sm90; - - // Default ctor - CUTLASS_HOST_DEVICE - ConvFilterFormatTransformer() {} - - struct Arguments { - const void *src_ptr; - void *dst_ptr; - FilterExtent filter_extent; - }; - - struct Params { - using TensorSrc = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(take<0,NumDimensions>(FilterExtent{})))); - using TensorDst = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(make_shape(int32_t(0), int32_t(0))))); - - TensorSrc src; - TensorDst dst; - }; - - struct SharedStorage { - /* empty, no smem needed */ - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - static Status - can_implement(Arguments const& args) { - bool implementable = true; - // alignment rule - { - int contiguous_dim = DstFormat == FilterFormat::CTRSK ? args.filter_extent[0] : args.filter_extent[NumDimensions - 1]; - int align_element = AlignmentBytes / static_cast(sizeof(Element)); - - implementable &= (contiguous_dim % align_element == 0); - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Alignment setting is invalid.\n"); - return Status::kInvalid; - } - } - - return Status::kSuccess; - } - - static size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static dim3 - get_block_shape() { - return dim3(size(shape(ThreadLayout)), 1, 1); - } - - static dim3 - get_grid_shape(Params const& params) { - auto dim_m = ceil_div(size<0>(shape(params.dst)), get<0>(CTATileShape)); - auto dim_n = ceil_div(size<1>(shape(params.dst)), get<1>(CTATileShape)); - - return dim3(dim_m, dim_n, 1); - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - return Status::kSuccess; - } - - static Params - to_underlying_arguments(Arguments const& args, void* workspace) { - auto k = args.filter_extent[0]; - auto c = args.filter_extent[NumDimensions - 1]; - auto srt = reverse(take<1,NumDimensions - 1>(args.filter_extent)); - - // source shape (s,r,t,k,c) - auto shape_src = flatten(make_shape(srt, k, c)); - auto shape_dst = DstFormat == FilterFormat::CTRSK ? make_shape(k, c * product(srt)) : make_shape(c, k * product(srt)); - - auto src = make_tensor(make_gmem_ptr(recast_ptr(args.src_ptr)), make_layout(shape_src)); - auto dst = make_tensor(make_gmem_ptr(recast_ptr(args.dst_ptr)), make_layout(shape_dst)); - - return Params{src, dst}; - } - - CUTLASS_DEVICE - void operator()(Params const& params, char *smem_buf) { - // Tile the input tensor into blocks - auto block_coord = make_coord(blockIdx.x, blockIdx.y); - auto block_shape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); - // Default thread layout: (4, 32) - auto thread_layout = make_layout(make_shape(Int<4>{}, Int<32>{})); - auto vec_layout = make_layout(make_shape(Int(sizeof(Element))>{}, Int<1>{})); - - Tensor tile_D = local_tile(params.dst, block_shape, block_coord); - - // Construct tiled copy - using AccessType = cutlass::AlignedArray; - using Atom = Copy_Atom, Element>; - - auto tiled_copy = make_tiled_copy(Atom{}, thread_layout, vec_layout); - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); - Tensor thr_tile_D = thr_copy.partition_D(tile_D); - - // shape (s, r, t) - auto shape_trs = take<0, NumDimensions - 2>(shape(params.src)); - // strided_c = c for format CTRSK, strided_c = k for format KTRSC - auto strided_c = DstFormat == FilterFormat::CTRSK ? get(shape(params.src)) : get(shape(params.src)); - // shape (s, r, t, c) for format CTRSK and shape (s, r, t, k) for format KTRSC - auto shape_ctrs = append(shape_trs, strided_c); - auto srtc_coord = idx2crd(int(blockIdx.y * get<1>(block_shape) + threadIdx.x / size<0>(thread_layout)), shape_ctrs); - // index of k for format CTRSK and index of c for format KTRSC - auto n_layout = make_layout(make_shape(gridDim.x, size<0>(thread_layout)), make_stride(size<0>(block_shape), size<0>(vec_layout))); - int n_idx = n_layout(make_coord(blockIdx.x, threadIdx.x % size<0>(thread_layout))); - - // Fragment to load from S and store to D - auto frag = make_fragment_like(thr_tile_D); - // Predicate tensor. - Tensor thr_tile_P = make_tensor(shape(thr_tile_D)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(frag); ++i) { - auto srt_coord = take<0, NumDimensions - 2>(srtc_coord); - auto kc_coord = DstFormat == FilterFormat::CTRSK ? - make_coord(n_idx+i, get(srtc_coord)) : - make_coord(get(srtc_coord), n_idx+i); - auto coord = flatten(make_coord(srt_coord, kc_coord)); - thr_tile_P(i) = elem_less(coord, shape(params.src)); - if (thr_tile_P(i)) { - frag(i) = params.src(coord); - } - } - - // Copy from RMEM to GMEM - copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D); - } -}; - -} // namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp deleted file mode 100644 index 577c68c341c5c7a3d26c7209b2c40e309c65abee..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +++ /dev/null @@ -1,603 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Compress utils specific for SM90 structure sparse kernels -*/ - -#pragma once - -#include "cute/container/bit_field.hpp" // cute::bit_field -#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v, cute::uint_bit_t -#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor -#include "cute/algorithm/cooperative_copy.hpp" // cute::cooperative_copy -#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 -#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter -#include "cutlass/cutlass.h" // cutlass::Status -#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t -#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up -#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo -#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes -#include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v -#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter - -namespace cutlass::transform::kernel { - -using namespace cute; - -template< - class ProblemShape_, - class ElementA_, - class LayoutATag_, - class SparseConfig_ -> -class SM90StructuredSparseCompressor { -public: - using SparseConfig = SparseConfig_; - using ProblemShape = ProblemShape_; - - // * EltA - using ElementA = ElementA_; - using ElementAUint = cute::uint_bit_t>; - using ElementAMma = typename SparseConfig::ElementAMma; - using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; - using ElementAMmaRawUnit = cute::uint_bit_t>; - using ElementASparsity = typename SparseConfig::ElementASparsity; - using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; - using ElementAUintCompressed = cute::sparse_elem; - using LayoutATag = LayoutATag_; - using LayoutA = LayoutATag; - using StrideA = cutlass::gemm::TagToStrideA_t; - - // * EltE - using ElementEMma = typename SparseConfig::ElementEMma; - using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; - using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; - // Data Type for storing one chunk's metadata - static constexpr int ElementEBitsPerChunk = typename SparseConfig::ElementEBitsPerChunk{}; - CUTE_STATIC_ASSERT(ElementEBitsPerChunk == 4, "ElementEBitsPerChunk is 4 for SM90"); - using ElementEChunk = cute::uint_bit_t; - CUTE_STATIC_ASSERT(cute::is_same_v, "ElementEChunk is uint4_t for SM90"); - using ElementESparsityPerChunk = Int / ElementEBitsPerChunk)>; - - // AtomE - using TensorEAtom = typename SparseConfig::TensorEAtom; - using TensorEAtomK = typename SparseConfig::TensorEAtomK; - using TensorEAtomM = typename SparseConfig::TensorEAtomM; - - static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; - static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; - static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; - static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - - // * Alignment - static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; - static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; - static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; - static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; - - // Required by `device_kernel` - static constexpr int MaxThreadsPerBlock = TensorEAtomM{}; - static constexpr int MinBlocksPerMultiprocessor = 1; - using ArchTag = arch::Sm90; - - struct SharedStorage { - ElementEMma cEsE[cute::size(TensorEAtom{})]; - ElementAUintCompressed cACsAC[cute::size(TensorEAtom{})]; - ElementAUint cAsA[cute::size(TensorEAtom{})]; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - struct TransformArguments { - void const* ptr_A{nullptr}; - StrideA dA{}; - void* ptr_ACompress{nullptr}; - void* ptr_E{nullptr}; - }; - - using TransformParams = TransformArguments; - - struct Arguments { - ProblemShape problem_shape{}; - TransformArguments transform{}; - KernelHardwareInfo hw_info{}; - }; - - struct Params { - ProblemShape problem_shape{}; - TransformParams transform{}; - KernelHardwareInfo hw_info{}; - void* workspace = nullptr; - }; - -public: - static Params - to_underlying_arguments(Arguments const& args, void* workspace = nullptr) { - CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::to_underlying_arguments()"); - return Params{{args.problem_shape}, - {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, - {args.hw_info}, - workspace}; - } - - static Status - can_implement(Arguments const& args) { - auto [M, N, K, L] = args.problem_shape; - if (K % LogicalElemsAPerChunk != 0) { - CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size"); - return Status::kErrorInvalidProblem; - } - CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::can_implement() (True)"); - return Status::kSuccess; - } - - static size_t - get_workspace_size(Arguments const& args) { - CUTLASS_UNUSED(args); - // Backward compatible with host compressor - CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_workspace_size() (" << SharedStorageSize << ")"); - return SharedStorageSize; - } - - static Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - CUTLASS_UNUSED(args); - CUTLASS_UNUSED(workspace); - CUTLASS_UNUSED(stream); - CUTLASS_UNUSED(cuda_adapter); - CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::initialize_workspace()"); - return Status::kSuccess; - } - - static dim3 - get_grid_shape(Params const& params) { - constexpr int MaxAlignmentM = cutlass::const_max(TensorEAlignmentM, TensorAAlignmentM); - constexpr int MaxAlignmentK = cutlass::const_max(TensorEAlignmentK, TensorAAlignmentK); - const auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape; - - const int GemmMAlignedMax = cutlass::round_up(GemmM, MaxAlignmentM); - const int GemmKAlignedMax = cutlass::round_up(GemmK, MaxAlignmentK); - - const int gridDim_X = cutlass::ceil_div(GemmMAlignedMax, TensorEAtomM{}); - const int gridDim_Y = cutlass::ceil_div(GemmKAlignedMax, TensorEAtomK{}); - const int gridDim_Z = GemmL; - - CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_grid_shape() (" - << gridDim_X << ", " - << gridDim_Y << ", " - << gridDim_Z << ")"); - return dim3(gridDim_X, gridDim_Y, gridDim_Z); - } - - static dim3 - get_block_shape() { - CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_block_shape() (" - << MaxThreadsPerBlock << ", " - << 1 << ", " - << 1 << ")"); - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTE_DEVICE - void - operator()(Params params, void* smem_buf = nullptr) { - run(params, smem_buf); - } - - CUTE_DEVICE - static void - run(Params params, void* smem_buf = nullptr) { - structure_sparse_compress(params, smem_buf); - } - -private: - - struct MetadataOneChunk1to2 { - - CUTE_DEVICE - void set_metadata_bits(int elt_log_idx, int elt_phy_idx) { - auto metadata_bits = [&]() -> uint8_t { - CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 2); - switch (elt_log_idx) { - case 0: - return 0b0100; - case 1: - return 0b1110; - default: - CUTE_GCC_UNREACHABLE; - } - }; - - storage_ |= (metadata_bits() << (4 * elt_phy_idx)); - } - - - CUTE_DEVICE - ElementEChunk storage() const { - return ElementEChunk{storage_}; - } - - private: - uint8_t storage_ = 0b0000; - }; - - struct MetadataOneChunk2to4{ - - CUTE_DEVICE - void set_metadata_bits(int elt_log_idx, int elt_phy_idx) { - auto metadata_bits = [&]() -> uint8_t { - CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 4); - switch (elt_log_idx) { - case 0: - return 0b00; - case 1: - return 0b01; - case 2: - return 0b10; - case 3: - return 0b11; - default: - CUTLASS_ASSERT(false); - CUTE_GCC_UNREACHABLE; - return 0b00; - } - }; - - storage_ |= (metadata_bits() << (2 * elt_phy_idx)); - } - - CUTE_DEVICE - ElementEChunk storage() const { - return ElementEChunk{storage_}; - } - - private: - uint8_t storage_ = 0b0000; - }; - - using MetadataOneChunk = cute::conditional_t; - -private: - - CUTE_DEVICE - static void - structure_sparse_compress(Params params, void* smem_buf) { - // * Input Params - auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape; - auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - [[maybe_unused]] const int gridDim_X = gridDim.x; - [[maybe_unused]] const int gridDim_Y = gridDim.y; - [[maybe_unused]] const int gridDim_Z = gridDim.z; - [[maybe_unused]] const int blockDim_X = blockDim.x; - - // * Global Tensor Layout - const cute::Layout layout_gA = make_layout(make_shape(GemmM, GemmK, GemmL), dA); - const cute::Layout layout_gAC = SparseConfig::fill_layoutA(params.problem_shape); - const cute::Layout layout_gE = SparseConfig::fill_layoutE(params.problem_shape); - - // * Construct Global Tensor - const cute::Tensor gA = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_A)), layout_gA); - cute::Tensor gAC_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_ACompress)), layout_gAC ); - cute::Tensor gAC = cute::recast(gAC_sparse); - cute::Tensor gE_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_E)), layout_gE); - cute::Tensor gE = cute::recast(gE_sparse); - - // * CTA Tensor Layout - using cAsA_layout_row = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutRight{})); - using cAsA_layout_col = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutLeft{})); - using cAsA_layout = cute::conditional_t, cAsA_layout_row, cAsA_layout_col>; - using cACsAC_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementASparsity{}), LayoutRight{})); - using cEsE_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementEMmaSparsity{}), LayoutRight{})); - - CUTE_STATIC_ASSERT(cute::is_static_v, "TensorEAtom needs to be static"); - CUTE_STATIC_ASSERT(cute::is_static_v, "cAsA_layout needs to be static"); - CUTE_STATIC_ASSERT(cute::is_static_v, "cACsAC_layout needs to be static"); - CUTE_STATIC_ASSERT(cute::is_static_v, "cEsE_layout needs to be static"); - - const int blockIdx_X = blockIdx.x; - const int blockIdx_Y = blockIdx.y; - const int blockIdx_Z = blockIdx.z; - const int threadIdx_X = threadIdx.x; - - // * Construct CTA Tensor - const auto cta_coord = make_coord(blockIdx_X, blockIdx_Y, blockIdx_Z); - cute::Tensor cAgA = cute::recast(local_tile(gA, shape(cAsA_layout{}), cta_coord)); - cute::Tensor cACgAC = cute::recast(local_tile(gAC, shape(cACsAC_layout{}), cta_coord)); - cute::Tensor cEgE = local_tile(gE, shape(cEsE_layout{}), cta_coord); - - cute::Tensor cAsA = cute::recast(make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cAsA)), cAsA_layout{})); - cute::Tensor cACsAC = cute::recast(make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cACsAC)), cACsAC_layout{})); - cute::Tensor cEsE = make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cEsE)), cEsE_layout{}); - cute::Tensor cEsE_chunk = cute::recast(cEsE); - - // * Handle in unit of Chunk when compress - using OneChunkSizeA = Int; - using OneChunkSizeAC = Int; - using OneChunkSizeE = Int; - using NumOneChunkK = Int; - - cute::Tensor cAsA_log_chunk = logical_divide(cAsA, make_shape(_, OneChunkSizeA{})); - cute::Tensor cACsAC_log_chunk = logical_divide(cACsAC, make_shape(_, OneChunkSizeAC{})); - cute::Tensor cEsE_log_chunk = logical_divide(cEsE_chunk, make_shape(_, OneChunkSizeE{})); - - // * Corner Case Handle - const auto GemmM_within_Cta = (GemmM - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmM - blockIdx_X * TensorEAtomM{}; - const auto GemmK_within_Cta = ( (GemmK - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmK - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw; - const auto GemmK_NumOneChunk_within_Cta = GemmK_within_Cta / LogicalElemsAMmaRawPerChunk; - - const auto GemmMAlignedAC = cutlass::round_up(GemmM, TensorAAlignmentM); - const auto GemmKAlignedAC = cutlass::round_up(GemmK, TensorAAlignmentK); - const auto GemmMAlignedAC_within_Cta = (GemmMAlignedAC - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmMAlignedAC - blockIdx_X * TensorEAtomM{}; - const auto GemmKAlignedAC_within_Cta = ( (GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw; - - // * Clear CTA Smem Tensor - cooperative_clear(threadIdx_X, cACsAC); - cooperative_clear(threadIdx_X, cEsE); - - // * Input CTA Tensor G to S - if (GemmM_within_Cta == TensorEAtomM{} && GemmK_within_Cta == TensorEAtomK{}) { - copy_vec_pred(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta); - } - else { - copy_vec_pred(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta); - } - - // Construct a sign bit mask for handling negative zeros - ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 }; - if constexpr (has_negative_zero_v) { - ElementAMmaRawUnit one_sign_mask = static_cast(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v - 1))); - for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) { - sign_mask = static_cast((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v)); - } - } - - // * Compress - // cACsAC is always row major order - // TensorEAtomM threads perform the compression, each thread compress one row - const int row_i = threadIdx_X; - if (row_i < GemmM_within_Cta) { - - CUTE_UNROLL - for (int col_chunk_i = 0; col_chunk_i < NumOneChunkK{}; ++col_chunk_i) { - if (col_chunk_i < GemmK_NumOneChunk_within_Cta) { - // Compress is handled in unit of ElementAMmaRawUnit - cute::Tensor tAsA = cAsA_log_chunk(row_i, make_coord(_, col_chunk_i)); - cute::Tensor tACsAC = cACsAC_log_chunk(row_i, make_coord(_, col_chunk_i)); - cute::Tensor tEsE = cEsE_log_chunk(row_i, make_coord(_, col_chunk_i)); - - int non_zero_cnt = 0; - // None zero element indx - // e.g. - // 2:4 sparsity [x 0 0 x] - // non_zero_elt_log_idx = [0, 3] - int non_zero_elt_log_idx[OneChunkSizeAC{}] = { 0 }; - - // * Find None Zero Element Idx within Chunk - CUTE_UNROLL - for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) { - ElementAMmaRawUnit elem_A = tAsA[elt_log_idx]; - - // Handle negative 0 - ElementAMmaRawUnit masked_elem_A = elem_A; - if constexpr (has_negative_zero_v) { - masked_elem_A = elem_A & sign_mask; - } - - if (masked_elem_A != ElementAMmaRawUnit{0}) { - non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx; - tACsAC[non_zero_cnt] = elem_A; - non_zero_cnt++; - } - } - - // * Corner Case for 2:4 sparsity - if constexpr (cute::sizeof_bits_v < 32) { - // i.e. [0 0 0 x] -> [(0) 0 0 x] - if (non_zero_cnt == 1 && non_zero_elt_log_idx[0] == 3) { - tACsAC[1] = tACsAC[0]; - tACsAC[0] = ElementAMmaRawUnit{0}; - non_zero_elt_log_idx[0] = 0; - non_zero_elt_log_idx[1] = 3; - } - // i.e. [0 0 x 0] -> [0 0 x (0)] - // i.e. [0 x 0 0] -> [0 x 0 (0)] - // i.e. [x 0 0 0] -> [x 0 0 (0)] - else if (non_zero_cnt == 1) { - tACsAC[1] = ElementAMmaRawUnit{0}; - non_zero_elt_log_idx[1] = 3; - } - } - - // * Set Metadata Bits - MetadataOneChunk metadata_one_chunk; - CUTE_UNROLL - for (int elt_phy_idx = 0; elt_phy_idx < OneChunkSizeAC{}; elt_phy_idx++) { - metadata_one_chunk.set_metadata_bits(non_zero_elt_log_idx[elt_phy_idx], elt_phy_idx); - } - tEsE[0] = metadata_one_chunk.storage(); - - } - else { - break; - } - } - } - - // * Sync after Compress - __syncthreads(); - - // * Output Cta Tensor S to G - if (GemmM_within_Cta > 0 && GemmK_within_Cta > 0) { - constexpr int MaxVecBits = 128; // STG.128 - cute::cooperative_copy(threadIdx_X, cEsE, cEgE); - } - - if (GemmMAlignedAC_within_Cta == TensorEAtomM{} && GemmKAlignedAC_within_Cta == TensorEAtomK{}) { - copy_vec_pred(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value)); - } - else { - copy_vec_pred(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value)); - } - - } // end of structure_sparse_compress() - - template - CUTE_DEVICE - static void - cooperative_clear( - uint32_t const& tid, - TensorSrc dSrc) { - - auto dSrctSrc = local_partition(dSrc, make_layout(make_shape(NumThreads, _1{})), tid); - cute::clear(dSrctSrc); - - // Sync all thread data access - __syncthreads(); - } - - template - CUTE_DEVICE - static void - copy_vec_pred( - TensorSrc dSrc, - TensorDst dDst, - int threadIdx_X, - int valid_rows, - int valid_cols) { - - constexpr bool IsRowMajor = cute::is_same_v; - using Element = typename TensorSrc::element_type; - constexpr bool IsQmmaF6 = cute::sizeof_bits_v == 6; - - CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dSrc) needs to be static"); - CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dDst) needs to be static"); - CUTE_STATIC_ASSERT(cute::sizeof_bits_v == cute::sizeof_bits_v, - "dSrc and dDst need to have same element bit width"); - CUTE_STATIC_ASSERT(cute::size(dSrc) == cute::size(dDst), "dSrc and dDst need to have same size"); - - // ValueShape - using ValueShape = - cute::conditional_t, Int<1>>, - cute::conditional_t, Int<128 / sizeof_bits_v>>, - Shape>, Int<1>>> - >; - - constexpr int ValueShapeRows = shape<0>(ValueShape{}); - constexpr int ValueShapeCols = shape<1>(ValueShape{}); - - // ThreadShape - using ThreadShape = - cute::conditional_t, Int<1>>, - Shape, Int>>, - cute::conditional_t(dSrc) / ValueShapeCols)>, Int< (shape<1>(dSrc) / ValueShapeCols)>>, - Shape(dSrc) / ValueShapeRows)>, Int(dSrc) / ValueShapeRows)>>> - >; - - constexpr int ThreadShapeRows = shape<0>(ThreadShape{}); - constexpr int ThreadShapeCols = shape<1>(ThreadShape{}); - - const int threadIdx_X_row = threadIdx_X / ThreadShapeCols; - const int threadIdx_X_col = threadIdx_X % ThreadShapeCols; - - // Row Major - if constexpr (IsRowMajor) { - CUTE_UNROLL - for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { - CUTE_UNROLL - for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { - CUTE_UNROLL - for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { - CUTE_UNROLL - for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { - const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; - const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; - if constexpr ( (not pred) and (not IsQmmaF6) ) { - dDst(row_i, col_i) = dSrc(row_i, col_i); - } - else { - if (row_i < valid_rows && col_i < valid_cols) { - dDst(row_i, col_i) = dSrc(row_i, col_i); - } - } - } - } - } - } - } - // Col Major - else { - CUTE_UNROLL - for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { - CUTE_UNROLL - for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { - CUTE_UNROLL - for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { - CUTE_UNROLL - for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { - const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; - const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; - if constexpr ( (not pred) and (not IsQmmaF6) ) { - dDst(row_i, col_i) = dSrc(row_i, col_i); - } - else { - if (row_i < valid_rows && col_i < valid_cols) { - dDst(row_i, col_i) = dSrc(row_i, col_i); - } - } - } - } - } - } - } - - // Sync all thread data access - __syncthreads(); - } // end of copy_vec_pred() - -}; - -} // namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp deleted file mode 100644 index 9f23535fea5df8df728b7c806d65f75f28c36aa3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp +++ /dev/null @@ -1,325 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Compress utils for structured sparse kernels -*/ - -#pragma once - -#include // std::fill -#include // std::array -#include // std::mt19937 - -#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v -#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor -#include "cutlass/arch/arch.h" // cutlass::arch::SmXY -#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false -#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t -#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up -#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes - -#include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp" - -namespace cutlass::transform::kernel { - -template< - class ProblemShape_, - class ElementA_, - class LayoutATag_, - class SparseConfig_ -> -class StructuredSparseCompressorUtility { -public: - using SparseConfig = SparseConfig_; - using ProblemShape = ProblemShape_; - - //* EltA - using ElementA = ElementA_; - using LayoutATag = LayoutATag_; - using StrideA = cutlass::gemm::TagToStrideA_t; - using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; - using ElementASparsity = typename SparseConfig::ElementASparsity; - using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; - - //* EltE - using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; - using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; - - //* AtomE - using TensorEAtom = typename SparseConfig::TensorEAtom; - using TensorEAtomK = typename SparseConfig::TensorEAtomK; - using TensorEAtomM = typename SparseConfig::TensorEAtomM; - - static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; - static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; - static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; - static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - - //* Alignment - static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; - static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; - static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; - static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; - - StructuredSparseCompressorUtility() = default; - - StructuredSparseCompressorUtility(ProblemShape problem, StrideA dA) { - set_problem_size(problem, dA); - } - - void set_problem_size(ProblemShape problem, StrideA dA_) { - M = cute::size<0>(problem); - K = cute::size<2>(problem); - L = cute::size<3>(problem); - - // The following three vars are logical elem count! - K_alignedA = round_up(K, TensorAAlignmentK); - M_alignedA = round_up(M, TensorAAlignmentM); - K_alignedE = round_up(K, TensorEAlignmentK); - M_alignedE = round_up(M, TensorEAlignmentM); - - dA = dA_; - } - - /** - * @brief Get the TensorE number of ElementE along K after alignment requirement - * - * @return int : number of ElementE (uint8_t) along K-dim - */ - int get_metadata_m_physical() const { - return M_alignedE; - } - - /** - * @brief Get the TensorE number of ElementE along M after alignment requirement - * - * @return int : number of ElementE (uint8_t) along M-dim - */ - int get_metadata_k_physical() const { - return K_alignedE / ElementEMmaSparsity{}; - } - - /** - * @brief Get the TensorACompressed number of ElementA along K after alignment requirement - * - * @return int : number of ElementA along K-dim - */ - int get_tensorA_k_physical() const { - return K_alignedA / ElementASparsity{}; - } - - /** - * @brief Get the TensorACompressed number of ElementA along M after alignment requirement - * - * @return int : number of ElementA along M-dim - */ - int get_tensorA_m_physical() const { - return M_alignedA; - } - - /** - * @brief Get the TensorACompressed Bytes - * - * @return uint64_t bytes - */ - uint64_t get_compressed_tensor_A_bytes() const { - const auto tensor_a_comp_num_elt_a = get_tensorA_m_physical() * get_tensorA_k_physical() * L; - const auto tensor_a_comp_bytes = cutlass::bits_to_bytes(tensor_a_comp_num_elt_a * cute::sizeof_bits_v); - return tensor_a_comp_bytes; - } - - /** - * @brief Get the TensorA Bytes - * - * @return uint64_t bytes - */ - uint64_t get_raw_tensor_A_bytes() const { - const auto tensor_a_num_elt_a = uint64_t(M) * uint64_t(K) * uint64_t(L); - const auto tensor_a_bytes = cutlass::bits_to_bytes(tensor_a_num_elt_a * cute::sizeof_bits_v); - return tensor_a_bytes; - } - - /** - * @brief Get the TensorE Bytes - * - * @return uint64_t bytes - */ - uint64_t get_tensor_E_bytes() const { - const auto tensor_e_num_elt_a = uint64_t(get_metadata_m_physical()) * uint64_t(get_metadata_k_physical()) * uint64_t(L); - const auto tensor_e_bytes = cutlass::bits_to_bytes(tensor_e_num_elt_a * cute::sizeof_bits_v); - return tensor_e_bytes; - } - - constexpr auto fill_layoutA_from_compressor() const { - return SparseConfig::fill_layoutA(cute::make_tuple(M,_1{},K,L)); - } - - constexpr auto fill_layoutE_from_compressor() const { - return SparseConfig::fill_layoutE(cute::make_tuple(M,_1{},K,L)); - } - - void structure_sparse_zero_mask_fill(void* host_a_ptr, uint64_t seed) { - - constexpr int ChunkSize = LogicalElemsAMmaRawPerChunk; - using ChunkElement = cute::uint_bit_t>; - - cute::Tensor gA_eltA = cute::make_tensor( - cute::recast_ptr(host_a_ptr), - cute::make_layout(make_shape(M, K, L), dA)); - - // Input TensorA is handled in unit of ElementAMmaRaw instead of ElementA - cute::Tensor gA = cute::recast(gA_eltA); - - // Extract out the Chunk from K-mode - Tensor gA_chunk = cute::zipped_divide(gA, cute::Shape<_1,cute::Int>{}); // (Chunk, Rest) - - // Half of the data is zero to indicate sparsityA = 2 - std::array nnzb_indicator{}; - for (size_t i = 1; i < nnzb_indicator.size(); i += 2) { - nnzb_indicator.at(i) = 1; - } - - std::mt19937 rng(seed); - auto rest_shape = cute::shape<1>(gA_chunk); - for (auto iter = cute::make_coord_iterator(rest_shape); iter != cute::ForwardCoordIteratorSentinel{}; ++iter) { - std::shuffle(nnzb_indicator.begin(), nnzb_indicator.end(), rng); - for (int c = 0; c < size<0>(gA_chunk); ++c) { // for each elem within chunk - if (nnzb_indicator[c] == 0) { - gA_chunk(c, *iter) = ChunkElement{0}; - } - } // end of within chunk - } // end of chunk_idx - } - - int M{-1}; - int K{-1}; - int L{-1}; - StrideA dA{}; - -private: - int K_alignedA{-1}; - int M_alignedA{-1}; - int K_alignedE{-1}; - int M_alignedE{-1}; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template< - class ProblemShape, - class ElementA, - class LayoutATag, - class SparseConfig, - class ArchTag -> -struct StructuredSparseCompressorSelector { - static_assert(cutlass::detail::dependent_false, - "Could not select a structured sparse compressor for given parameters."); -}; - -template< - class ProblemShape, - class ElementA, - class LayoutATag, - class SparseConfig -> -struct StructuredSparseCompressorSelector< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig, - arch::Sm90> { - using Compressor = SM90StructuredSparseCompressor< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig - >; -}; - -template< - class ProblemShape, - class ElementA, - class LayoutATag, - class SparseConfig -> -struct StructuredSparseCompressorSelector< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig, - arch::Sm100> { - using Compressor = SM90StructuredSparseCompressor< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig - >; -}; - -template< - class ProblemShape, - class ElementA, - class LayoutATag, - class SparseConfig -> -struct StructuredSparseCompressorSelector< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig, - arch::Sm120> { - using Compressor = SM90StructuredSparseCompressor< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig - >; -}; - -template< - class ProblemShape, - class ElementA, - class LayoutATag, - class SparseConfig, - class ArchTag -> -using StructuredSparseCompressor = typename StructuredSparseCompressorSelector< - ProblemShape, - ElementA, - LayoutATag, - SparseConfig, - ArchTag ->::Compressor; - -} // End namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h deleted file mode 100644 index ef553aab2043775758c2a87d422456dc5cca2426..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h +++ /dev/null @@ -1,926 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing how threads are mapped to a given tile. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/layout/pitch_linear.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { - -//////////////////////////////////////////////////////////////////////////////// - -/// Strip-mines a pitch-linear tile among a given number of threads, first along -/// the contiguous dimension then along the strided dimension. -/// -/// The tile must be divisible by the thread count such that all threads may -/// execute the same number of iterations with the same delta to exhaustively -/// cover the tile. -/// -/// This class satisfies the "RegularThreadMapping" concept. -/// -/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor -/// kernels. -template < - typename Shape_, - int Threads, - int ElementsPerAccess = 1 -> -struct PitchLinearStripminedThreadMap { - - /// Tensor coordinate - using TensorCoord = layout::PitchLinearCoord; - - /// Tile shape - using Shape = Shape_; - - /// Number of threads total - static int const kThreads = Threads; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ElementsPerAccess; - - /// Shape of access by each thread - using ThreadAccessShape = layout::PitchLinearShape; - - /// Internal implementation details - struct Detail { - - static_assert(!(Shape::kContiguous % kElementsPerAccess), ""); - - /// Shape of the tile in units of vectors - using ShapeVec = layout::PitchLinearShape< - Shape::kContiguous / kElementsPerAccess, - Shape::kStrided - >; - - static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || - (!(kThreads % ShapeVec::kContiguous)), - "Shape must be divisible by number of iterations of each thread."); - }; - - /// Number of iterations by each thread - using Iterations = typename platform::conditional< - Threads >= Detail::ShapeVec::kContiguous, - layout::PitchLinearShape< - 1, - // Redo the comparison here to work around divide by zero compiler - // error. The compiler evaluates both path of platform::conditional. - (Threads >= Detail::ShapeVec::kContiguous - ? (Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) / - (kThreads / Detail::ShapeVec::kContiguous) - : 0)>, - layout::PitchLinearShape>::type; - - - /// Interval between accesses along each dimension of the tensor's logical coordinate space - /// (in units of Elements) - using Delta = typename platform::conditional< - Threads >= Detail::ShapeVec::kContiguous, - layout::PitchLinearShape< - 1, - kThreads / Detail::ShapeVec::kContiguous - >, - layout::PitchLinearShape< - kThreads * kElementsPerAccess, - 1 - > - >::type; - - /// Shape of the tile in units of vectors - using StorageShape = typename platform::conditional< - Threads >= Detail::ShapeVec::kContiguous, - layout::PitchLinearShape, - layout::PitchLinearShape>::type; - - /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space - /// (in units of Elements) - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - return TensorCoord( - (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess, - thread_id / Detail::ShapeVec::kContiguous); - } -}; - -/// This ThreadMap is used by GEMV -template < - typename Shape, - int Threads, - int ElementsPerAccess = 1 -> -struct PitchLinearTilePolicyStripminedThreadContiguous -{ - static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0, - "Contiguous shape must divide number of threads"); - - using TensorCoord = layout::PitchLinearCoord; - - static int const kThreads = Threads; - static int const kElementsPerAccess = ElementsPerAccess; - - using Iterations = layout::PitchLinearShape< - Shape::kContiguous / (kThreads * kElementsPerAccess), - Shape::kStrided>; - - using Delta = layout::PitchLinearShape<1, 1>; - - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) - { - return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0); - } -}; - -template < - typename Shape, - int Threads, - int ElementsPerAccess = 1 -> -struct PitchLinearTilePolicyStripminedThreadStrided -{ - static_assert((Shape::kStrided % Threads == 0), - "Strided shape must divide number of threads"); - - using TensorCoord = layout::PitchLinearCoord; - - static int const kThreads = Threads; - static int const kElementsPerAccess = ElementsPerAccess; - - using Iterations = layout::PitchLinearShape< - Shape::kContiguous / kElementsPerAccess, - Shape::kStrided / kThreads>; - - using Delta = layout::PitchLinearShape<1, 1>; - - using ShapeVec = Shape; - - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) - { - - return TensorCoord(0, thread_id * Iterations::kStrided); - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous -/// elements. -/// -/// This ThreadMap is used by tensor core kernels. -template < - typename Shape_, - int Threads, - typename WarpThreadArrangement_, - int ElementsPerAccess = 1 -> -struct PitchLinearWarpRakedThreadMap { - - /// Tensor coordinate - using TensorCoord = layout::PitchLinearCoord; - - /// Tile shape - using Shape = Shape_; - - /// Number of threads total - static int const kThreads = Threads; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ElementsPerAccess; - - /// Shape of access by each thread - using ThreadAccessShape = layout::PitchLinearShape; - - /// Internal details made public to facilitate introspection - struct Detail { - - /// Fixed arrangement of threads within a warp (units of threads). - using WarpThreadArrangement = WarpThreadArrangement_; - - /// Number of threads per warp - static int const kWarpSize = WarpThreadArrangement::kCount; - - /// Number of participating warps - static int const kWarpCount = kThreads / kWarpSize; - - static_assert( - !(Shape::kContiguous % kElementsPerAccess), - "Shape must be divisible by vector length."); - - /// Compute the 'shape' of the overall tile in units of vectors - using ShapeInAccesses = layout::PitchLinearShape< - Shape::kContiguous / kElementsPerAccess, - Shape::kStrided - >; - - static_assert( - !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous), - "ShapeInAccesses must be divisible by WarpThreadArrangement."); - - static_assert( - !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided), - "ShapeInAccesses must be divisible by WarpThreadArrangement."); - - // compute number of warp-level accesses total - using WarpAccessIterations = layout::PitchLinearShape< - ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, - ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided - >; - - // Divide it into the number of warps, first partitioning the strided dimension then the - // contiguous. - static int const kWarpsStrided = - (WarpAccessIterations::kStrided >= kWarpCount - ? kWarpCount - : WarpAccessIterations::kStrided); - - static int const kWarpsContiguous = - (kWarpCount > WarpAccessIterations::kStrided - ? kWarpCount / kWarpsStrided - : 1); - - /// Arrangement of warps within a threadblock-scoped tile - using WarpArrangement = layout::PitchLinearShape< - kWarpsContiguous, kWarpsStrided - >; - }; - - ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = layout::PitchLinearShape< - Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, - Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided - >; - - static_assert(Iterations::kCount, - "Number of iterations must be non-zero"); - - ///< Delta between accesses (units of elements, concept: PitchLinearShape) - using Delta = layout::PitchLinearShape< - Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, - Detail::WarpThreadArrangement::kStrided - >; - - /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - int warp_id = (thread_id / Detail::kWarpSize); - int lane_id = (thread_id % Detail::kWarpSize); - - // - // compute warp-level offset - // - - // This is the shape of the entire area covered by a warp's memory access (in units of vectors) - layout::PitchLinearCoord warp_footprint{ - Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, - Detail::WarpThreadArrangement::kStrided * Iterations::kStrided - }; - - // This is the offset of a specific warp (in units of vectors) - layout::PitchLinearCoord warp_offset{ - (warp_id % Detail::kWarpsContiguous), - (warp_id / Detail::kWarpsContiguous) - }; - - // This is the offset of a specific thread within a warp (units of vectors) - layout::PitchLinearCoord thread_offset_in_warp{ - lane_id % Detail::WarpThreadArrangement::kContiguous, - lane_id / Detail::WarpThreadArrangement::kContiguous - }; - - // This is the offset of a thread within a threadblock tile (units of vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = - warp_footprint * warp_offset + thread_offset_in_warp; - - // This is the offset of a thread within a threadblock tile (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ - thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, - thread_offset_in_threadblock_tile_vec.strided() - }; - - return thread_offset_in_threadblock_tile_base; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous -/// elements. Warps are arranged based on a stride. -/// -/// This ThreadMap is used by tensor core kernels for NCxHWx layout. -template < - typename Shape_, - int Threads, - typename WarpThreadArrangement_, - int ElementsPerAccess = 1 -> -struct PitchLinearStridedWarpRakedThreadMap { - - /// Tensor coordinate - using TensorCoord = layout::PitchLinearCoord; - - /// Tile shape - using Shape = Shape_; - - /// Number of threads total - static int const kThreads = Threads; - - using WarpThreadArrangement = WarpThreadArrangement_; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ElementsPerAccess; - - /// Base ThreadMap - using BaseThreadMap = PitchLinearWarpRakedThreadMap< - Shape, - kThreads, - WarpThreadArrangement, - kElementsPerAccess - >; - - /// Shape of access by each thread - using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape; - - - struct Detail { - - using WarpThreadArrangement = WarpThreadArrangement_; - - using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations; - - static int const kWarpSize = BaseThreadMap::Detail::kWarpSize; - - static int const kWarpCount = BaseThreadMap::Detail::kWarpCount; - - using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses; - - // Divide it into the number of warps, first partitioning the contiguous dimension then the - // stride. - static int const kWarpsContiguous = - (WarpAccessIterations::kContiguous >= kWarpCount - ? kWarpCount - : WarpAccessIterations::kContiguous); - - static int const kWarpsStrided = - (kWarpCount > WarpAccessIterations::kContiguous - ? kWarpCount / kWarpsContiguous - : 1); - - /// Arrangement of warps within a threadblock-scoped tile - using WarpArrangement = layout::PitchLinearShape< - kWarpsContiguous, kWarpsStrided - >; - - }; - - ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = layout::PitchLinearShape< - Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, - Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided - >; - - static_assert(Iterations::kCount, - "Number of iterations must be non-zero"); - - ///< Delta between accesses (units of elements, concept: PitchLinearShape) - using Delta = typename BaseThreadMap::Delta; - - /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - int warp_id = (thread_id / Detail::kWarpSize); - int lane_id = (thread_id % Detail::kWarpSize); - - // - // compute warp-level offset - // - - // This is the shape of the entire area covered by a warp's memory access (in units of vectors) - layout::PitchLinearCoord warp_footprint{ - Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, - Detail::WarpThreadArrangement::kStrided * Iterations::kStrided - }; - - // This is the offset of a specific warp (in units of vectors) - layout::PitchLinearCoord warp_offset{ - (warp_id % Detail::kWarpsContiguous), - (warp_id / Detail::kWarpsContiguous) - }; - - // This is the offset of a specific thread within a warp (units of vectors) - layout::PitchLinearCoord thread_offset_in_warp{ - lane_id % Detail::WarpThreadArrangement::kContiguous, - lane_id / Detail::WarpThreadArrangement::kContiguous - }; - - // This is the offset of a thread within a threadblock tile (units of vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = - warp_footprint * warp_offset + thread_offset_in_warp; - - // This is the offset of a thread within a threadblock tile (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ - thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, - thread_offset_in_threadblock_tile_vec.strided() - }; - - return thread_offset_in_threadblock_tile_base; - } - - -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Transpose the existing ThreadMap. For example, interleaved layout is like -/// congruous in the global memory and crosswise in the shared memory. We need -/// to transpose the coordinates between two. - -template -struct TransposePitchLinearThreadMap { - /// Underlying ThreadMap - using ThreadMap = ThreadMap_; - - /// Tensor coordinate - using TensorCoord = typename ThreadMap::TensorCoord; - - /// Tile shape - using Shape = typename ThreadMap::Shape; - - /// Number of threads total - static int const kThreads = ThreadMap::kThreads; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - /// Shape of access by each thread - using ThreadAccessShape = layout::PitchLinearShape; - - /// Internal details made public to facilitate introspection - struct Detail { - /// Fixed arrangement of threads within a warp (units of threads). - using WarpThreadArrangement = WarpThreadArrangement_; - - /// Number of threads per warp - static int const kWarpSize = WarpThreadArrangement::kCount; - - /// Number of participating warps - static int const kWarpCount = kThreads / kWarpSize; - - static_assert(!(Shape::kContiguous % kElementsPerAccess), - "Shape must be divisible by vector length."); - - /// Arrangement of warps within a threadblock-scoped tile - using WarpArrangement = - layout::PitchLinearShape; - }; - - ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = - layout::PitchLinearShape; - - static_assert(Iterations::kContiguous == 1, - "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose"); - - static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - - ///< Delta between accesses (units of elements, concept: PitchLinearShape) - using Delta = - layout::PitchLinearShape; - - /// Maps thread ID to a coordinate offset within the tensor's logical - /// coordinate space Note this is slightly different from the one of - /// PitchLinearWarpRakedThreadMap. - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - int warp_id = (thread_id / Detail::kWarpSize); - int lane_id = (thread_id % Detail::kWarpSize); - - // - // compute warp-level offset - // - - // This is the shape of the entire area covered by a warp's memory access - // (in units of vectors) - layout::PitchLinearCoord warp_footprint{ - Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, - Detail::WarpThreadArrangement::kStrided * Iterations::kStrided}; - - // This is the offset of a specific warp (in units of vectors) - // Note the order of / and %. Also the 2nd operand is kStrided. - layout::PitchLinearCoord warp_offset{ - (warp_id / Detail::WarpArrangement::kStrided), - (warp_id % Detail::WarpArrangement::kStrided)}; - - // This is the offset of a specific thread within a warp (units of vectors) - layout::PitchLinearCoord thread_offset_in_warp{ - lane_id % Detail::WarpThreadArrangement::kContiguous, - lane_id / Detail::WarpThreadArrangement::kContiguous}; - - // This is the offset of a thread within a threadblock tile (units of - // vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = - warp_footprint * warp_offset + thread_offset_in_warp; - - // This is the offset of a thread within a threadblock tile (units of - // elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ - thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, - thread_offset_in_threadblock_tile_vec.strided()}; - - return thread_offset_in_threadblock_tile_base; - } -}; - -template -struct TransposePitchLinearThreadMapSimt { - /// Underlying ThreadMap - using ThreadMap = ThreadMap_; - - /// Tensor coordinate - using TensorCoord = typename ThreadMap::TensorCoord; - - /// Tile shape - using Shape = typename ThreadMap::Shape; - - /// Number of threads total - static int const kThreads = ThreadMap::kThreads; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1"); - ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = - layout::PitchLinearShape; - - static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - - static_assert(Iterations::kStrided == 1, - "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose"); - - /// Shape of access by each thread - using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; - - ///< Delta between accesses (units of elements, concept: PitchLinearShape) - using Delta = - layout::PitchLinearShape; - - - /// Maps thread ID to a coordinate offset within the tensor's logical - /// coordinate space Note this is slightly different from the one of - /// PitchLinearWarpRakedThreadMap. - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - TensorCoord coord = ThreadMap::initial_offset(thread_id); - - return TensorCoord( - coord.strided(), - coord.contiguous() - ); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - - -/// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory -/// accesses performed by each warp then distributes warps across them. Warps are striped in the -/// strided dimension and raked across the contiguous dimension. -template < - typename Shape_, /// Overall shape to partition in units of elements - int Threads, /// Number of partiticipation threads - typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp - int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. vector size) -> -struct PitchLinearWarpStripedThreadMap { - - /// Tensor coordinate - using TensorCoord = layout::PitchLinearCoord; - - /// Tile shape - using Shape = Shape_; - - /// Number of threads total - static int const kThreads = Threads; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ElementsPerAccess; - - /// Shape of access by each thread - using ThreadAccessShape = layout::PitchLinearShape; - - /// Internal details made public to facilitate introspection - struct Detail { - - /// Fixed arrangement of threads within a warp (units of threads). - using WarpThreadArrangement = WarpThreadArrangement_; - - /// Number of threads per warp - static int const kWarpSize = WarpThreadArrangement::kCount; - - /// Number of participating warps - static int const kWarpCount = kThreads / kWarpSize; - - static_assert( - !(Shape::kContiguous % kElementsPerAccess), - "Shape must be divisible by vector length."); - - /// Compute the 'shape' of the overall tile in units of vectors - using ShapeInAccesses = layout::PitchLinearShape< - Shape::kContiguous / kElementsPerAccess, - Shape::kStrided - >; - - // compute number of warp-level accesses total - using WarpAccessIterations = layout::PitchLinearShape< - ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, - ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided - >; - - // Divide it into the number of warps, first partitioning the strided dimension then the - // contiguous. - static int const kWarpsStrided = - (WarpAccessIterations::kStrided >= kWarpCount - ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided)); - - static int const kWarpsContiguous = - (kWarpCount > WarpAccessIterations::kStrided ? - WarpAccessIterations::kContiguous / kWarpsStrided : 1); - - /// Arrangement of warps within a threadblock-scoped tile - using WarpArrangement = layout::PitchLinearShape< - kWarpsContiguous, kWarpsStrided - >; - }; - - ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = layout::PitchLinearShape< - Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, - Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided - >; - - static_assert(Iterations::kCount, - "Number of iterations must be non-zero"); - - ///< Delta between accesses (units of elements, concept: PitchLinearShape) - using Delta = layout::PitchLinearShape< - Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, - Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided - >; - - /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - int warp_id = (thread_id / Detail::kWarpSize); - int lane_id = (thread_id % Detail::kWarpSize); - - // - // compute warp-level offset - // - - // This is the shape of the entire area covered by a warp's memory access (in units of vectors) - layout::PitchLinearCoord warp_footprint{ - Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, - Detail::WarpThreadArrangement::kStrided - }; - - // This is the offset of a specific warp (in units of vectors) - layout::PitchLinearCoord warp_offset{ - (warp_id % Detail::kWarpsContiguous), - (warp_id / Detail::kWarpsContiguous) - }; - - // This is the offset of a specific thread within a warp (units of vectors) - layout::PitchLinearCoord thread_offset_in_warp{ - lane_id % Detail::WarpThreadArrangement::kContiguous, - lane_id / Detail::WarpThreadArrangement::kContiguous - }; - - // This is the offset of a thread within a threadblock tile (units of vectors) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = - warp_footprint * warp_offset + thread_offset_in_warp; - - // This is the offset of a thread within a threadblock tile (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ - thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, - thread_offset_in_threadblock_tile_vec.strided() - }; - - return thread_offset_in_threadblock_tile_base; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous -/// dimension then along the strided dimension, while each thread access a 2D thread-tile. -/// -/// The tile must be divisible by the thread count such that all threads may execute the same -/// number of iterations with the same delta to exhaustively cover the tile. -/// -/// This class satisfies the "RegularThreadMapping" concept. -template < - typename Shape_, - int Threads, - typename ThreadTileShape -> -struct PitchLinear2DThreadTileStripminedThreadMap; - - -template < - typename Shape_, - int Threads -> -struct PitchLinear2DThreadTileStripminedThreadMap >{ - - /// Tensor coordinate - using TensorCoord = layout::PitchLinearCoord; - - /// Tile shape - using Shape = Shape_; - - /// Access Shape of each thread - using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>; - //using ThreadAccessShape = ThreadTileShape; - - /// Number of threads total - static int const kThreads = Threads; - - /// Extract length of each access from Layout - static int const kElementsPerAccess = ThreadAccessShape::kContiguous; - - static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)"); - - /// Internal implementation details - struct Detail { - - static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4"); - - static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), ""); - - static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)), - "Shape must be divisible thread count * accesses per thread."); - - /// Shape of the tile in units of vectors - using ShapeVec = layout::PitchLinearShape< - Shape::kContiguous / ThreadAccessShape::kContiguous, - Shape::kStrided / ThreadAccessShape::kStrided - >; - - static_assert( - (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || - (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))), - "Shape must be divisible by number of iterations of each thread." - ); - }; - - /// Number of iterations by each thread - using Iterations = typename platform::conditional< - Threads >= Detail::ShapeVec::kContiguous, - layout::PitchLinearShape< - 1, - // Redo the comparison here to work around divide by zero compiler - // error. The compiler evaluates both path of platform::conditional. - (Threads >= Detail::ShapeVec::kContiguous - ? Detail::ShapeVec::kStrided / - (kThreads / Detail::ShapeVec::kContiguous) - : 0)>, - layout::PitchLinearShape>::type; - - /// Interval between accesses along each dimension of the tensor's logical coordinate space - /// (in units of Elements) - using Delta = typename platform::conditional< - Threads >= Detail::ShapeVec::kContiguous, - layout::PitchLinearShape< - Shape::kContiguous, - kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous - >, - layout::PitchLinearShape< - kThreads * ThreadAccessShape::kContiguous, - 1 - > - >::type; - - /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space - /// (in units of Elements) - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - return TensorCoord( - (thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous, - (thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided); - } -}; - -/// Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping -template -struct TransposePitchLinearThreadMap2DThreadTile { - /// Underlying ThreadMap - using ThreadMap = ThreadMap_; - - /// Tensor coordinate - using TensorCoord = typename ThreadMap::TensorCoord; - - /// Tile shape - using Shape = typename ThreadMap::Shape; - - /// Number of threads total - static int const kThreads = ThreadMap::kThreads; - - /// Extract vector length from Layout - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - - static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1"); - ///< Iterations along each dimension (concept: PitchLinearShape) - using Iterations = - layout::PitchLinearShape; - - static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - - /// Shape of access by each thread - using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; - - ///< Delta between accesses (units of elements, concept: PitchLinearShape) - using Delta = - layout::PitchLinearShape; - - - /// Maps thread ID to a coordinate offset within the tensor's logical - /// coordinate space Note this is slightly different from the one of - /// PitchLinearWarpRakedThreadMap. - CUTLASS_HOST_DEVICE - static TensorCoord initial_offset(int thread_id) { - - TensorCoord coord = ThreadMap::initial_offset(thread_id); - return TensorCoord( - coord.strided(), - coord.contiguous() - ); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h deleted file mode 100644 index 508cad846e6d6b819c26570e5dcae9844f712089..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h +++ /dev/null @@ -1,107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Basic copy routines for tensor views -*/ - -#pragma once - -namespace cutlass { -namespace transform { -namespace thread { - -/// Transforms a fragment by doing a transpose -template < - int ElementCount, - typename TransposeShape, - typename Element -> struct Transpose; - -/// Specialization for int8_t 4x4 transpose -template -struct Transpose , int8_t> { - - static const int kElementCount = ElementCount_; - using TransposeShape = layout::PitchLinearShape<4,4>; - using Element = int8_t; - using Fragment = cutlass::Array; - - static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose"); - - CUTLASS_DEVICE - void transform(Fragment& dst, Fragment& src) { - - // Expose src/dst as int arrays. - int* src_int = reinterpret_cast(&src); - int* dst_int = reinterpret_cast(&dst); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){ - - int const i0 = 4 * i + 0; - int const i1 = 4 * i + 1; - int const i2 = 4 * i + 2; - int const i3 = 4 * i + 3; - - int a0 = src_int[i0]; - int a1 = src_int[i1]; - int a2 = src_int[i2]; - int a3 = src_int[i3]; - - int b0, b1, b2, b3, c0; - b0 = __byte_perm(a0, a1, 0x0040); - c0 = __byte_perm(a2, a3, 0x0040); - b0 = __byte_perm(b0, c0, 0x5410); - - b1 = __byte_perm(a0, a1, 0x0051); - c0 = __byte_perm(a2, a3, 0x0051); - b1 = __byte_perm(b1, c0, 0x5410); - - b2 = __byte_perm(a0, a1, 0x0062); - c0 = __byte_perm(a2, a3, 0x0062); - b2 = __byte_perm(b2, c0, 0x5410); - - b3 = __byte_perm(a0, a1, 0x0073); - c0 = __byte_perm(a2, a3, 0x0073); - b3 = __byte_perm(b3, c0, 0x5410); - - dst_int[i0] = b0; - dst_int[i1] = b1; - dst_int[i2] = b2; - dst_int[i3] = b3; - } - } -}; - -} // namespace thread -} // namespace layout -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h deleted file mode 100644 index 3977af529124dc3db34610046b72145c2a14bf00..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h +++ /dev/null @@ -1,105 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" - -namespace cutlass { -namespace transform { -namespace thread { - -namespace UnaryTransform { - struct Identity; ///< None (i.e., identity) - struct Conjugate; ///< Complex conjugate -} - -/// Element-wise unary operator that transforms one element of a fragment at a time -template< - typename FragmentIn, ///< Input Fragment - typename FragmentOut,///< Output Fragment - typename Transform> ///< Unary transform operator -class UnaryOp -{ - public: - CUTLASS_DEVICE - static FragmentOut execute(FragmentIn &in) - { - static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match."); - static_assert(platform::is_same::value || - platform::is_same::value, - "Unary Operator not supported."); - - FragmentOut out; - if (platform::is_same::value ) - { - CUTLASS_PRAGMA_UNROLL - for (int i=0; i < FragmentIn::kElements; ++i){ - out[i] = static_cast(in[i]); - } - } - else if (platform::is_same::value ) - { - for (int i=0; i < FragmentIn::kElements; ++i){ - out[i] = conj(static_cast(in[i])); - } - } - return out; - } -}; - -template -class UnaryOp -{ - public: - CUTLASS_DEVICE - static FragmentIn execute(FragmentIn &in) - { - static_assert(platform::is_same::value || - platform::is_same::value, - "Unary Operator not supported."); - - if (platform::is_same::value ) - { - return in; - } - else if (platform::is_same::value ) - { - for(int i=0; i < FragmentIn::kElements; ++i){ - in[i] = conj(in[i]); - } - } - return in; - } - }; - } - } -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h deleted file mode 100644 index bd717d678f8234b9fd39f7d22c4de5c231da4c42..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h +++ /dev/null @@ -1,199 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Ell iterator for matrix of indices (ellColInd matrix) -*/ - -#pragma once - -namespace cutlass { -namespace transform { -namespace threadblock { - -namespace ell{ - -constexpr unsigned int SmemPow = 8; -constexpr unsigned int SmemStages = 2; -constexpr unsigned int SmemSize = 1 << SmemPow; -constexpr unsigned int SmemMask = (SmemSize*SmemStages-1); - -class SharedStorage{ - public: - Array array; -}; - -class Iterator{ - public: - using Layout = layout::PitchLinear; - using LongIndex = typename Layout::LongIndex; - - private: - const int *gmem_col_idx_; - int *smem_col_idx_; - const int block_size_; - const int base_idx_; - const int k_shape_; - const int ell_increment_; - const int array_length_; - int col_idx_base_; - int residue_; - int counter_; - - int pow2_; - int residue_shape_; - - int smem_offset_; - int smem_stage_; - int gmem_offset_; - - int lane_; - - bool is_pow2_; - bool is_residue_tile_; - - public: - CUTLASS_DEVICE - void load_ell_indices(){ - for(int i=threadIdx.x; i= 0) ? gmem_col_idx : -1; - } - gmem_offset_ += SmemSize; - smem_stage_ ^= 1; - } - - CUTLASS_DEVICE - Iterator( - SharedStorage& shared_storage_base, - const int* col_idx, - const int& block_size, - const int& base_idx, - const int k_shape, - const int& problem_size_k, - const int& ell_stride, - const int& thread_idx) - : residue_(0), - counter_(0), - smem_offset_(0), - smem_stage_(0), - gmem_offset_(0), - block_size_(block_size), - base_idx_(base_idx), - k_shape_(k_shape), - ell_increment_(ell_stride * block_size), - array_length_((problem_size_k + block_size_ - 1) / block_size_), - residue_shape_(problem_size_k % k_shape_), - is_residue_tile_(residue_shape_ != 0), - smem_col_idx_(reinterpret_cast(&shared_storage_base.array)), - gmem_col_idx_(const_cast(col_idx)), - lane_(thread_idx % 32) { - - load_ell_indices(); - __syncthreads(); - - is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0); - if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0; - - col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_; - - pow2_ = 0; - while(block_size_ >> (pow2_ + 1)) ++pow2_; - } - - CUTLASS_DEVICE - int get_blocksize(){ - return block_size_; - } - - CUTLASS_DEVICE - Iterator &operator++(){ - if(is_residue_tile_){ - residue_ += residue_shape_; - is_residue_tile_ = false; - } else { - residue_ += k_shape_; - } - - if(residue_ < block_size_){ - return *this; - } - - if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_)) - load_ell_indices(); - - if(residue_ == block_size_){ - ++smem_offset_; - counter_ += ell_increment_; - residue_ = 0; - col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; - return *this; - } - - if(is_pow2_){ - smem_offset_ += residue_ >> pow2_; - counter_ += (residue_ >> pow2_) * ell_increment_; - residue_ = residue_ & ((1 << pow2_) - 1); - } - else { - smem_offset_ += residue_ / block_size_; - counter_ += (residue_ / block_size_) * ell_increment_; - residue_ %= block_size_; - } - - col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; - - return *this; - } - - CUTLASS_DEVICE - LongIndex get_offset(const int& idx) { - int num_jump_tiles; - if(is_pow2_) - num_jump_tiles = (idx + residue_) >> pow2_; - else - num_jump_tiles = (idx + residue_) / block_size_; - - int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles); - return tmp - num_jump_tiles * ell_increment_; - } - - CUTLASS_DEVICE - LongIndex get_offset_fast() { - return col_idx_base_; - } -}; - -} -} -} -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h deleted file mode 100644 index 3676c2339067f9eaad667e11e0d798ae3f4d5c95..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h +++ /dev/null @@ -1,1350 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// EllPredicatedTileAccessIterator -/// -template -class EllPredicatedTileAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -/// -template -class EllPredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static int const kPredicatesPerByte = 4; - static int const kPredicatesPerWord = 4 * kPredicatesPerByte; - - static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; - - /// Number of 32b words containing predicates - static int const kPredicateByteCount = - (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; - static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; - - static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; - - static_assert(kPredicateWordCount <= 4, "Too many predicates."); - - /// Predicate vector stores mask to guard accesses - using Mask = Array; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - friend EllPredicatedTileAccessIterator; - - private: - /// stride of pitch-linear layout (units of Element) - LongIndex stride_; - /// amount (in byte) to increment pointer to move to next access along - /// strided dimension - LongIndex inc_strided_; - /// amount (in byte) to increment pointer from last access to first access - /// of next tile - LongIndex inc_next_; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_; - - public: - - // Default ctor - CUTLASS_HOST_DEVICE - Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : stride_(layout.stride(0)) { - inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * - sizeof_bits::value / 8; - - if (kAdvanceRank) { - // advance along strided dimension - inc_advance_ = - Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; - } else { - // advance along contiguous dimension - inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; - } - - inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * - ThreadMap::Delta::kStrided * LongIndex(stride_) * - sizeof_bits::value / 8; - }; - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const ¶ms_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - - /// Guard predicates - uint32_t predicates_[kPredicateWordCount]; - - /// Size of tensor - TensorCoord extent_; - - /// Initial offset for each thread - TensorCoord thread_offset_; - - /// Offset to the first steady-state tile - TensorCoord residue_offset_; - - /// Initial offset to define ELL block - TensorCoord ell_offset_; - - /// Used for out-of-order visitation - bool is_residue_tile_; - - /// Iteration along vectors implied by the thread map - int iteration_vector_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - /// Computes predicates based on internally tracked per-thread offset. - CUTLASS_DEVICE - void compute_predicates_( - /// Extent of the matrix window - TensorCoord extent, - /// optionally, simplify predicate calculation during 'steady state' phase - bool is_steady_state = false) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0u; - } - - CUTLASS_PRAGMA_UNROLL - for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { - - int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int c = access_residual / kAccessesPerVector; - int v = access_residual % kAccessesPerVector; - - TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, - s * ThreadMap::Delta::kStrided); - - TensorCoord coord = thread_offset_ + iteration_coord; - - bool guard; - - if (is_steady_state) { - if (kAdvanceRank == 0) { - guard = (coord.strided() < extent.strided()); - } else { - guard = (coord.contiguous() < extent.contiguous()); - } - } else { - guard = (coord.strided() < extent.strided() && - coord.contiguous() < extent.contiguous()); - } - - int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); - - } - - } - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : params_(params), - pointer_(reinterpret_cast( - const_cast(pointer))), - extent_(extent), - is_residue_tile_(true) { - - TensorCoord residue_extent; - if (kAdvanceRank) { - - typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; - if (!residue_size) { - residue_size = Shape::kStrided; - } - - residue_offset_ = make_Coord(0, residue_size); - residue_extent = make_Coord( - extent_.contiguous(), - min(threadblock_offset.strided() + residue_size, extent_.strided()) - ); - } else { - - typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; - if (!residue_size) { - residue_size = Shape::kContiguous; - } - - residue_offset_ = make_Coord(residue_size, 0); - - residue_extent = make_Coord( - min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), - extent_.strided() - ); - } - - // Per-thread offset in logical coordinates of tensor - ell_offset_ = ThreadMap::initial_offset(thread_id); - thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); - - // update internal pointers - Layout layout(params_.stride_); - add_pointer_offset(layout(thread_offset_)); - - compute_predicates_(residue_extent, false); - - set_iteration_index(0); - } - - /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id) - : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += sizeof_bits::value * pointer_offset / 8; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - if (is_residue_tile_) { - - thread_offset_ += residue_offset_; - - Layout layout(params_.stride_); - add_pointer_offset(layout(residue_offset_)); - - compute_predicates_(extent_, true); - - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } else { - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } - is_residue_tile_ = false; - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast( - pointer_ + - iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; - } - - /// Returns a k_location - CUTLASS_HOST_DEVICE - int get_k() const { - if(kAdvanceRank){ //strided - return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided; - }else{ - return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements; - } - } - - CUTLASS_HOST_DEVICE - int get_stride() const { - if(kAdvanceRank) - return params_.stride_; - else - return 1; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator &operator++() { - - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - - iteration_vector_ = 0; - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - pointer_ += params_.inc_strided_; - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - // advance to next tile - pointer_ += params_.inc_next_; - - // now return to start tile - if the iterator is subsequently advanced, this - // subtraction as well as the subsequent integer addition are both elided by - // the compiler. - pointer_ -= params_.inc_advance_; - - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator operator++(int) { - EllPredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = enable ? 0u : predicates_[i]; - } - - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0xffffffff; - } - - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = mask[i]; - } - - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - mask[i] = predicates_[i]; - } - } - - /// add mask for small tiles in ELL - CUTLASS_DEVICE - void ell_add_mask(int blocksize) { - - Mask mask; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - mask[i] = 0u; - } - - CUTLASS_PRAGMA_UNROLL - for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { - - int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int c = access_residual / kAccessesPerVector; - int v = access_residual % kAccessesPerVector; - - TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, - s * ThreadMap::Delta::kStrided); - - TensorCoord coord = ell_offset_ + iteration_coord; - - bool guard; - - if (kAdvanceRank == 0) { - guard = (coord.strided() < blocksize); - } else { - guard = (coord.contiguous() < blocksize); - } - - int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); - - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - mask[i] &= predicates_[i]; - } - set_mask(mask); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - - int pred_idx = - iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; - return pred; - - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class EllPredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend EllPredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))){}; - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), - threadblock_offset.column())) {} - - /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - CUTLASS_HOST_DEVICE - int get_k() const { - return iterator_.get_k(); - } - - CUTLASS_HOST_DEVICE - int get_stride() const { - return iterator_.get_stride(); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator operator++(int) { - EllPredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_DEVICE - void ell_add_mask(int blocksize) { - iterator_.ell_add_mask(blocksize); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class EllPredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend EllPredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))){}; - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - CUTLASS_HOST_DEVICE - int get_k() const { - return iterator_.get_k(); - } - - CUTLASS_HOST_DEVICE - int get_stride() const { - return iterator_.get_stride(); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator operator++(int) { - EllPredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_DEVICE - void ell_add_mask(int blocksize) { - iterator_.ell_add_mask(blocksize); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data. -/// It is mapped to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// - -template -class EllPredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::ColumnMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileAccessIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, - AccessType>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend EllPredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row() * kInterleavedK, - extent.column() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.row() * kInterleavedK, - threadblock_offset.column() / kInterleavedK)) {} - - /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - CUTLASS_HOST_DEVICE - int get_k() const { - return iterator_.get_k(); - } - - CUTLASS_HOST_DEVICE - int get_stride() const { - return iterator_.get_stride(); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator operator++(int) { - EllPredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_DEVICE - void ell_add_mask(int blocksize) { - iterator_.ell_add_mask(blocksize); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { return iterator_.valid(); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data. -/// It is mapped to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class EllPredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::RowMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileAccessIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, - AccessType>; - - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend EllPredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column() * kInterleavedK, - extent.row() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.column() * kInterleavedK, - threadblock_offset.row() / kInterleavedK)) {} - - /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - CUTLASS_HOST_DEVICE - int get_k() const { - return iterator_.get_k(); - } - - CUTLASS_HOST_DEVICE - int get_stride() const { - return iterator_.get_stride(); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileAccessIterator operator++(int) { - EllPredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_DEVICE - void ell_add_mask(int blocksize) { - iterator_.ell_add_mask(blocksize); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { return iterator_.valid(); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h deleted file mode 100644 index e377bba4c454267737bffda73b1dff7572174ee7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h +++ /dev/null @@ -1,1315 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined -*/ - -#pragma once - -#include "cutlass/arch/memory.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" - -#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h" -#include "cutlass/transform/threadblock/ell_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// EllPredicatedTileIterator -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -/// Regular tile iterator using a precomputed control structure to minimize register liveness -/// and integer arithmetic. -/// -/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -/// -/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -/// Subsequently, they are assumed to be immutable. -/// -/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -/// -/// Visitation order is intended to first visit a "residual" tile that may be partially full in -/// both the advance dimension and the steady-state dimension. This is assumed to be the last -/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -/// accesses may be performed without updating internal predicates and are efficient in terms of -/// live register state and pointer arithmetic instructions. -/// -/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -/// outside any looping structure to minimize integer arithmetic. -/// -/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -/// the iterator. -/// -/// -/// Example: -/// -/// An efficient pipeline structure may be constructed as follows: -/// -// template -// __global__ void kernel( -// typename Iterator::Params params, -// typename Iterator::Element *ptr, -// TensorCoord extent) { -// -// typename Iterator::Fragment fragment; -// -// TensorCoord threadblock_offset(0, 0); -// -// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -// -// -// fragment = *iter; // load "residue" tile first -// ++iter; // advance to first "steady state" tile and update internal masks -// -// -// #pragma unroll -// for (int i = Remaining - 1; i >= 0; --i) { -// -// f(fragment); -// -// if (!i) { -// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -// } -// -// fragment = *iter; // load tile during "steady state" phase -// ++iter; // advance to next tile - lightweight due to steady-state masks -// } -// } -// -// void host(TensorView view) { -// -// using Iterator = transform::threadblock::EllPredicatedTileIterator; -// -// typename Iterator::Params params(view.layout()); -// -// kernel(params, view.data()); -// } -/// -/// -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - int AccessSize = ThreadMap::kElementsPerAccess -> -class EllPredicatedTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileIterator for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class EllPredicatedTileIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - /// Type used for internal memory accesses - using AccessType = AlignedArray::value / 8)>; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = - EllPredicatedTileAccessIterator; - - static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename TileAccessIterator::Mask; - - /// Iterator for ELL storage - using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - friend EllPredicatedTileIterator; - - private: - /// Parameters object - typename TileAccessIterator::Params params_; - - public: - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : params_(layout) { } - - CUTLASS_HOST_DEVICE - Params() { } - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : address_iterator_(params.params_, pointer, extent, thread_id, - threadblock_offset) {} - - /// Construct a EllPredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator &operator++() { - if (kAdvanceRank) - address_iterator_.add_tile_offset({0, 1}); - else - address_iterator_.add_tile_offset({1, 0}); - - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator operator++(int) { - EllPredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Returns a stride - CUTLASS_HOST_DEVICE - int get_stride() const { return address_iterator_.get_stride(); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { address_iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_HOST_DEVICE - void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); } - - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - address_iterator_.set_iteration_index(idx); - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - cutlass::arch::global_load( - frag_ptr[idx], access_ptr, address_iterator_.valid()); - - ++address_iterator_; - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_byte_offset(frag, 0); } - - CUTLASS_DEVICE - void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - address_iterator_.set_iteration_index(idx); - LongIndex ell_offset = 0; - - int k_offset = address_iterator_.get_k(); - ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element); - - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; - - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - bool is_valid = address_iterator_.valid(); - is_valid = is_valid && (ell_offset >= 0); - - cutlass::arch::global_load( - frag_ptr[idx], access_ptr, is_valid); - - ++address_iterator_; - } - } - } - } - - CUTLASS_DEVICE - void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) { - - LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element); - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - address_iterator_.set_iteration_index(idx); - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; - - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - bool is_valid = address_iterator_.valid(); - is_valid = is_valid && (ell_offset >= 0); - - cutlass::arch::global_load( - frag_ptr[idx], access_ptr, is_valid); - - ++address_iterator_; - } - } - } - } - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType *access_ptr = reinterpret_cast(byte_ptr); - - if (address_iterator_.valid()) { - *access_ptr = frag_ptr[idx]; - } - ++address_iterator_; - } - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileIterator for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int AccessSize -> -class EllPredicatedTileIterator { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap, - AccessSize - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Iterator for ELL storage - using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend EllPredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { - - } - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset ///< Initial offset of threadblock - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) - ) { } - - /// Construct a EllPredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator operator++(int) { - EllPredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Returns a stride - CUTLASS_HOST_DEVICE - int get_stride() const { return iterator_.get_stride(); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// add mask for small tiles in ELL - CUTLASS_HOST_DEVICE - void ell_add_mask(int blocksize) { - iterator_.ell_add_mask(blocksize); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { - iterator_.load_with_ell_index(frag, ell_iter); - } - - CUTLASS_DEVICE - void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { - iterator_.load_with_ell_index_fast(frag, ell_iter); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileIterator for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int AccessSize -> -class EllPredicatedTileIterator { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - AccessSize - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Iterator for ELL storage - using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend EllPredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { - - }; - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset ///< Initial offset of threadblock - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) - ) { } - - /// Construct a EllPredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator operator++(int) { - EllPredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Returns a stride - CUTLASS_HOST_DEVICE - int get_stride() const { return iterator_.get_stride(); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// add mask for small tiles in ELL - CUTLASS_HOST_DEVICE - void ell_add_mask(int blocksize) { - iterator_.ell_add_mask(blocksize); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { - iterator_.load_with_ell_index(frag, ell_iter); - } - - CUTLASS_DEVICE - void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { - iterator_.load_with_ell_index_fast(frag, ell_iter); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped -/// to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// - -template -class EllPredicatedTileIterator, - AdvanceRank, ThreadMap_, AccessSize> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::ColumnMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; - - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Iterator for ELL storage - using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend EllPredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row() * kInterleavedK, - extent.column() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.row() * kInterleavedK, - threadblock_offset.column() / kInterleavedK)) {} - - /// Construct a EllPredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator operator++(int) { - EllPredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Returns a stride - CUTLASS_HOST_DEVICE - int get_stride() const { return iterator_.get_stride(); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_HOST_DEVICE - void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - CUTLASS_DEVICE - void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { - iterator_.load_with_ell_index(frag, ell_iter); - } - - CUTLASS_DEVICE - void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { - iterator_.load_with_ell_index_fast(frag, ell_iter); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is -/// mapped to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class EllPredicatedTileIterator, - AdvanceRank, ThreadMap_, AccessSize> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::RowMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = EllPredicatedTileIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; - - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend EllPredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column() * kInterleavedK, - extent.row() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.column() * kInterleavedK, - threadblock_offset.row() / kInterleavedK)) {} - - /// Construct a EllPredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : EllPredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - EllPredicatedTileIterator operator++(int) { - EllPredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Returns a stride - CUTLASS_HOST_DEVICE - int get_stride() const { return iterator_.get_stride(); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// add mask for small tiles in ELL - CUTLASS_HOST_DEVICE - void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h deleted file mode 100644 index dab597c835ced1a4f070858b26da3007d268c04e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h +++ /dev/null @@ -1,375 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates calculating the address and predicates to the load of scale and bias vectors. - - This iterator uses masks to guard out-of-bounds accesses. - - It can be used to load the gamma and beta vectors of layernorm which is loop variant. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/conv/threadblock/conv2d_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedScaleBiasVectorAccessIterator -/// -template -class PredicatedScaleBiasVectorAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. -/// -template -class PredicatedScaleBiasVectorAccessIterator { - public: - - using ThreadblockShape = ThreadblockShape_; - using Element = Element_; - using Layout = layout::PitchLinear; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kElementsPerAccess = 128 / sizeof_bits::value; - static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; - - using AccessType = AlignedArray; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Internal pointer to first access of tile - BytePointer pointer_; - - TensorCoord thread_offset_; - - int problem_size_k_; - - /// Used for out-of-order visitation - bool is_residue_tile_; - - bool guard_; - - TensorCoord::Index residue_size_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - /// Extent of tensor - int problem_size_k, - /// Pointer to the start of the scale vector - ConstPointer scale_pointer, - /// Pointer to the start of the bias vector - ConstPointer bias_pointer, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) { - pointer_ = (thread_id < kThreads) - ? reinterpret_cast( - const_cast(scale_pointer)) - : reinterpret_cast( - const_cast(bias_pointer)); - - // Per-thread offset in logical coordinates of tensor - int thread_base = (thread_id < kThreads) ? 0 : kThreads; - - problem_size_k_ = problem_size_k; - - is_residue_tile_ = true; - - residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous; - - if (residue_size_ == 0) { - residue_size_ = ThreadblockShape::kContiguous; - } - - guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_; - - thread_offset_ = - threadblock_offset + - TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); - - set_iteration_index(0); - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - /// Extent of tensor - int problem_size_k, - /// Pointer to start of scale vector - ConstPointer scale_pointer, - /// Pointer to start of scale vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id) - : PredicatedScaleBiasVectorAccessIterator(problem_size_k, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) {} - - /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - - guard_ = threadIdx.x < kThreads * 2; - - TensorCoord offset = is_residue_tile_ ? - TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0) - : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); - - thread_offset_ = - thread_offset_ + - offset; - - is_residue_tile_ = false; - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - return reinterpret_cast( - pointer_ + - (thread_offset_.contiguous() * sizeof_bits::value / 8)); - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator &operator++() { - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_DEVICE - PredicatedScaleBiasVectorAccessIterator operator++(int) { - PredicatedScaleBiasVectorAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - guard_ &= (!enable); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return guard_; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedScaleBiasVectorAccessIterator { - public: - - using ThreadblockShape = ThreadblockShape_; - using Element = Element_; - using Layout = layout::RowMajor; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear>; - - using AccessType = typename UnderlyingIterator::AccessType; - static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - ///< Extent of tensor - int problem_size_k, - ///< Pointer to the start of the scale vector - ConstPointer scale_pointer, - ///< Pointer to the start of the bias vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(problem_size_k, scale_pointer, bias_pointer, - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator( - int problem_size_k, ///< Extent of tensor - ConstPointer scale_pointer, ///< Pointer to the start of the scale vector - ConstPointer bias_pointer, ///< Pointer to the start of the bias vector - int thread_id ///< ID of each participating thread - ) - : PredicatedScaleBiasVectorAccessIterator(problem_size_k, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// threadblock tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorAccessIterator operator++(int) { - PredicatedScaleBiasVectorAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h deleted file mode 100644 index e5d9e70d73bfcbdc27ab78bbedea1278c3b25950..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h +++ /dev/null @@ -1,328 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates calculating the address and predicates to the load of scale and bias vectors. - - This iterator uses masks to guard out-of-bounds accesses. - - This can be used to load var and mean vectors in layernorm which is loop invariant. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedScaleBiasVectorIterator -/// -template -class PredicatedScaleBiasVectorIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. -/// -template -class PredicatedScaleBiasVectorIterator { - public: - - using WarpShape = WarpShape_; - using Element = Element_; - using Layout = layout::PitchLinear; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kElementsPerAccess = 1; - - using AccessType = AlignedArray; - - static int const kIterations = WarpShape::kContiguous / 8; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; - - private: - // - // Data members - // - - /// Internal pointer to first access of tile - ConstPointer scale_pointer_; - ConstPointer bias_pointer_; - - /// Size of tensor - int problem_size_; - - int32_t thread_offset_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - /// Extent of tensor - int problem_size, - /// Pointer to the start of the scale vector - ConstPointer scale_pointer, - /// Pointer to the start of the bias vector - ConstPointer bias_pointer, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : problem_size_(problem_size), - scale_pointer_(scale_pointer), - bias_pointer_(bias_pointer) { - - thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; - } - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - /// Extent of tensor - int problem_size, - /// Pointer to start of scale vector - ConstPointer scale_pointer, - /// Pointer to start of scale vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id) - : PredicatedScaleBiasVectorIterator(problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - - thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - frag.fill(__float2half2_rn(0.0f)); - __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); - - // load scale - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - - cutlass::arch::global_load< - __half, - sizeof(AccessType) - >( - frag_ptr[c * 2].x, - scale_pointer_ + thread_offset_ + c * 8, - (thread_offset_ + c * 8) < problem_size_ - ); - } - - // load bias - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - - cutlass::arch::global_load< - __half, - sizeof(AccessType) - >( - frag_ptr[c * 2 + 1].x, - bias_pointer_ + thread_offset_ + c * 8, - (thread_offset_ + c * 8) < problem_size_ - ); - } - - // duplicate scale - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - frag_ptr[c * 2].y = frag_ptr[c * 2].x; - } - - // duplicate bias - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedScaleBiasVectorIterator { - public: - - using WarpShape = WarpShape_; - using Element = Element_; - using Layout = layout::RowMajor; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedScaleBiasVectorIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear>; - - using AccessType = typename UnderlyingIterator::AccessType; - static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; - using Fragment = typename UnderlyingIterator::Fragment; - - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - ///< Extent of tensor - int problem_size, - ///< Pointer to the start of the scale vector - ConstPointer scale_pointer, - ///< Pointer to the start of the bias vector - ConstPointer bias_pointer, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(problem_size, scale_pointer, bias_pointer, - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedScaleBiasVectorIterator( - int problem_size, ///< Extent of tensor - ConstPointer scale_pointer, ///< Pointer to the start of the scale vector - ConstPointer bias_pointer, ///< Pointer to the start of the bias vector - int thread_id ///< ID of each participating thread - ) - : PredicatedScaleBiasVectorIterator(problem_size, - scale_pointer, bias_pointer, - thread_id, make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// threadblock tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - iterator_.load(frag); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h deleted file mode 100644 index 3640709868602584f93e3409a251c0baff19d18d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ /dev/null @@ -1,2118 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates calculating the address and predicates to the load of tiles - from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses. The first tile this - iterator visits maybe partial, then the remaining tiles are complete. So, we - only need to compute the predicates twice, once before the first tile and - once for the remaining full tiles which can share the same predicates. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/permute.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileAccessIteratorPredicates -/// -template -class PredicatedTileAccessIteratorPredicates { - public: - using Shape = Shape_; - using Element = Element_; - using Layout = Layout_; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorCoord = typename Layout::TensorCoord; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static int const kPredicatesPerByte = 4; - static int const kPredicatesPerWord = 4 * kPredicatesPerByte; - - static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; - - /// Number of 32b words containing predicates - static int const kPredicateByteCount = - (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; - static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; - - static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; - - static_assert(kPredicateWordCount <= 4, "Too many predicates."); - - /// Predicate vector stores mask to guard accesses - using Mask = Array; - -// private: - /// Guard predicates - uint32_t predicates_[kPredicateWordCount]; - - /// Size of tensor - TensorCoord extent_; - - /// Initial offset for each thread - TensorCoord thread_offset_; - - /// Offset to the first steady-state tile - TensorCoord residue_offset_; - - /// Iteration along vectors implied by the thread map - int iteration_vector_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - /// Computes predicates based on internally tracked per-thread offset. - CUTLASS_DEVICE - void compute_predicates_( - /// Extent of the matrix window - TensorCoord extent, - /// optionally, simplify predicate calculation during 'steady state' phase - bool is_steady_state = false) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0u; - } - - CUTLASS_PRAGMA_UNROLL - for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { - - int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int c = access_residual / kAccessesPerVector; - int v = access_residual % kAccessesPerVector; - - TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, - s * ThreadMap::Delta::kStrided); - - TensorCoord coord = thread_offset_ + iteration_coord; - - bool guard; - - if (is_steady_state) { - if (kAdvanceRank == 0) { - guard = (coord.strided() < extent.strided()); - } else { - guard = (coord.contiguous() < extent.contiguous()); - } - } else { - guard = (coord.strided() < extent.strided() && - coord.contiguous() < extent.contiguous()); - } - - int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); - - } - - } - - CUTLASS_HOST_DEVICE - void set_predicates(int thread_id, TensorCoord const &threadblock_offset) { - - TensorCoord residue_extent; - if (kAdvanceRank) { - - typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; - if (!residue_size) { - residue_size = Shape::kStrided; - } - - residue_offset_ = make_Coord(0, residue_size); - residue_extent = make_Coord( - extent_.contiguous(), - min(threadblock_offset.strided() + residue_size, extent_.strided()) - ); - } else { - - typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; - if (!residue_size) { - residue_size = Shape::kContiguous; - } - - residue_offset_ = make_Coord(residue_size, 0); - - residue_extent = make_Coord( - min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), - extent_.strided() - ); - } - - // Per-thread offset in logical coordinates of tensor - thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); - - compute_predicates_(residue_extent, false); - - set_iteration_index(0); - } - - /// Default constructor - PredicatedTileAccessIteratorPredicates() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorPredicates( - /// Extent of tensor - TensorCoord extent) - : extent_(extent) { - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorPredicates &operator++() { - - return *this; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = enable ? 0u : predicates_[i]; - } - - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0xffffffff; - } - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = mask[i]; - } - - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - mask[i] = predicates_[i]; - } - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() const { - - - int pred_idx = - iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; - return pred; - - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileAccessIterator -/// -template -class PredicatedTileAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for pitch-linear data. -/// -template -class PredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< - Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static bool constexpr Permute = !platform::is_same::value - && !platform::is_same>::value; - - using Mask = typename UnderlyingPredicates::Mask; - - /// Uses a non-template class - struct Params : PredicatedTileAccessIteratorParams { - - using Base = PredicatedTileAccessIteratorParams; - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : - Base(layout.stride(0), - MakePredicatedTileAccessIteratorDesc()() - ) { } - - CUTLASS_HOST_DEVICE - Params(Base const &base) : - Base(base) { } - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - UnderlyingPredicates the_predicates; - - /// Parameters object with precomputed internal state - Params params_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - - /// Used for out-of-order visitation - bool is_residue_tile_; - - /// Below is used when Gather is turned on. We need to record strided_offset - /// and contiguous_offset separated to compute the offset by using - /// - /// offset = contiguous_offset + indices[strided_offset] - - /// Gather indices - int const *indices_; - - /// Function to perform layout permutation and offset computation - PermuteLayout permute_layout_; - - /// Tracks thread's coordinate offset in the matrix for current tile. - /// This is only used in the following cases: - /// - when Gather is true, strided coordinate needed to access indices (contiguous offset is tracked via pointer_) - /// - when Permute is true, both coordinates are needed as input into permutation function (pointer_ is fixed) - TensorCoord coord_offset_; - - private: - /// Computes predicates based on internally tracked per-thread offset. - CUTLASS_DEVICE - void compute_predicates_( - /// Extent of the matrix window - TensorCoord extent, - /// optionally, simplify predicate calculation during 'steady state' phase - bool is_steady_state = false) { - the_predicates.compute_predicates_(extent, is_steady_state); - } - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - /// Gather indices - int const *indices = nullptr) - : params_(params), - pointer_(reinterpret_cast( - const_cast(pointer))), - the_predicates(extent), - is_residue_tile_(true), - indices_(indices), - permute_layout_(TensorCoord(extent.contiguous(), extent.strided()), params.stride_) { - - the_predicates.set_predicates(thread_id, threadblock_offset); - - if (Gather) { - assert(indices_); - } - - // update internal pointers - Layout layout(params_.stride_); - - if (!Gather && !Permute) { - add_pointer_offset(layout(the_predicates.thread_offset_)); - } else { - coord_offset_ = the_predicates.thread_offset_; - if (!Permute) { - add_pointer_offset(layout(make_Coord(coord_offset_.contiguous(), 0))); - } - } - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - the_predicates.set_iteration_index(index); - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += sizeof_bits::value * pointer_offset / 8; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - if (is_residue_tile_) { - - the_predicates.thread_offset_ += the_predicates.residue_offset_; - - the_predicates.compute_predicates_(the_predicates.extent_, true); - - Layout layout(params_.stride_); - - if (!Gather && !Permute) { - add_pointer_offset(layout(the_predicates.residue_offset_)); - - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); - pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits::value / 8; - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); - pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits::value / 8; - } - } else { - coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank); - if (!Permute) { - add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0))); - add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank))); - } else { - coord_offset_.contiguous() = the_predicates.thread_offset_.contiguous() + Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank)); - } - } - } else { - if (!Gather && !Permute) { - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } else { - coord_offset_.strided() += Shape::kStrided * tile_offset.strided(); - if (!Permute) { - add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); - } else { - coord_offset_.contiguous() += Shape::kContiguous * tile_offset.contiguous(); - } - } - } - - is_residue_tile_ = false; - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - if (Gather || Permute) - { - if (!valid()) { - return nullptr; - } - - Index coord_contig = (Permute ? coord_offset_.contiguous() : 0) + the_predicates.iteration_contiguous_ * ThreadMap::Delta::kContiguous + the_predicates.iteration_vector_ * AccessType::kElements; - Index coord_strided = coord_offset_.strided() + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; - if (Gather) { - coord_strided = indices_[coord_strided]; - } - - LongIndex offset = Permute ? permute_layout_(TensorCoord(coord_contig, coord_strided)) : (coord_strided * LongIndex(params_.stride_) + coord_contig); - return reinterpret_cast(pointer_ + OffsetBytes(offset)); - } - - return reinterpret_cast( - pointer_ + - the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + the_predicates.iteration_vector_; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - - the_predicates.operator++(); - - ++the_predicates.iteration_vector_; - if (the_predicates.iteration_vector_ < kAccessesPerVector) { - return *this; - } - - the_predicates.iteration_vector_ = 0; - ++the_predicates.iteration_contiguous_; - - if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - - // Enter here only if (iteration_contiguous_ == ThreadMap::Iteration::kContiguous) - the_predicates.iteration_contiguous_ = 0; - ++the_predicates.iteration_strided_; - - if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { - if (!Gather && !Permute) { - pointer_ += params_.inc_strided_; - } - - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - the_predicates.iteration_strided_ = 0; - - if (!Gather && !Permute) { - // advance to next tile - pointer_ += params_.inc_next_; - - // now return to start tile - if the iterator is subsequently advanced, this - // subtraction as well as the subsequent integer addition are both elided by - // the compiler. - pointer_ -= params_.inc_advance_; - } - - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - the_predicates.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - the_predicates.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - the_predicates.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - the_predicates.get_mask(mask); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() const { - return the_predicates.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for column-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType, - Gather, PermuteLayout>; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))){}; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), - threadblock_offset.column()), - indices) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType, - Gather, PermuteLayout>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))){}; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset, - /// Gather indices - int const *indices = nullptr) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row()), - indices) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for affine rank 2 data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_, false, - layout::NoPermute> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::AffineRankN<2>; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< - Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingPredicates::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - friend PredicatedTileAccessIterator; - - private: - /// stride of pitch-linear layout (units of Element) - Coord stride_; - /// amount (in byte) to increment pointer to move to next access along - /// contiguous dimension - LongIndex inc_contiguous_; - /// amount (in byte) to increment pointer from first access of current - /// contiguous dimension to first access of next one. - LongIndex inc_strided_; - /// amount (in byte) to increment pointer from last access of current - /// contiguous dimension to first access of next one. - LongIndex inc_next_strided_; - /// amount (in byte) to increment pointer from last access to first access - /// of next tile - LongIndex inc_next_; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_; - - public: - - // Default ctor - CUTLASS_HOST_DEVICE - Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) { - inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * - sizeof_bits::value / 8; - - inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * - sizeof_bits::value / 8; - - inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; - - if (kAdvanceRank) { - // advance along strided dimension - inc_advance_ = - Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; - } else { - // advance along contiguous dimension - inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; - } - - inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; - }; - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - // - // Data members - // - - /// Parameters object with precomputed internal state - Params params_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - - UnderlyingPredicates the_predicates; - - /// Used for out-of-order visitation - bool is_residue_tile_; - - private: - /// Computes predicates based on internally tracked per-thread offset. - CUTLASS_DEVICE - void compute_predicates_( - /// Extent of the matrix window - TensorCoord extent, - /// optionally, simplify predicate calculation during 'steady state' phase - bool is_steady_state = false) { - the_predicates.compute_predicates_(extent, is_steady_state); - } - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : params_(params), - pointer_(reinterpret_cast( - const_cast(pointer))), - the_predicates(extent), - is_residue_tile_(true) { - - the_predicates.set_predicates(thread_id, threadblock_offset); - - // update internal pointers - Layout layout(params_.stride_); - add_pointer_offset(layout(the_predicates.thread_offset_)); - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += sizeof_bits::value * pointer_offset / 8; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - if (is_residue_tile_) { - - the_predicates.thread_offset_ += the_predicates.residue_offset_; - - Layout layout(params_.stride_); - add_pointer_offset(layout(the_predicates.residue_offset_)); - - the_predicates.compute_predicates_(the_predicates.extent_, true); - - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1); - pointer_ += Shape::kContiguous * tile_offset[0]; - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1); - pointer_ += Shape::kStrided * tile_offset[1]; - } - } else { - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); - pointer_ += Shape::kContiguous * tile_offset[0]; - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); - pointer_ += Shape::kStrided * tile_offset[1]; - } - } - is_residue_tile_ = false; - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - the_predicates.operator++(); - ++the_predicates.iteration_vector_; - if (the_predicates.iteration_vector_ < kAccessesPerVector) { - return *this; - } - - the_predicates.iteration_vector_ = 0; - ++the_predicates.iteration_contiguous_; - - if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - pointer_ += params_.inc_contiguous_; - return *this; - } - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - the_predicates.iteration_contiguous_ = 0; - ++the_predicates.iteration_strided_; - - if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { - pointer_ += params_.inc_next_strided_; - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - the_predicates.iteration_strided_ = 0; - - // advance to next tile - pointer_ += params_.inc_next_; - - // now return to start tile - if the iterator is subsequently advanced, this - // subtraction as well as the subsequent integer addition are both elided by - // the compiler. - pointer_ -= params_.inc_advance_; - - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { the_predicates.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { the_predicates.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { the_predicates.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return the_predicates.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::AffineRank2ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - // Map to the underlying AffineRankN<2> layout - using UnderlyingIterator = PredicatedTileAccessIterator< - layout::PitchLinearShape, Element, - layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given an AffineRankN<2> tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; - }; - - private: - // - // Data members - // - - /// Underlying AffineRankN<2> tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), - threadblock_offset.column())) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::AffineRank2RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - // Map to the underlying AffineRankN<2> layout - using UnderlyingIterator = PredicatedTileAccessIterator< - layout::PitchLinearShape, Element, - layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given an AffineRankN<2> tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; - }; - - private: - // - // Data members - // - - /// Underlying AffineRankN<2> tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. -/// It is mapped to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// - -template -class PredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_, false, - layout::NoPermute> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::ColumnMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, - AccessType>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row() * kInterleavedK, - extent.column() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.row() * kInterleavedK, - threadblock_offset.column() / kInterleavedK)) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { return iterator_.valid(); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for row-major interleaved data. -// It is mapped to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_, false, - layout::NoPermute> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::RowMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, - AccessType>; - - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileAccessIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column() * kInterleavedK, - extent.row() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.column() * kInterleavedK, - threadblock_offset.row() / kInterleavedK)) {} - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { return iterator_.valid(); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h deleted file mode 100644 index 93eac72e40ddf6b0f3d268957873417e5d5a442f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h +++ /dev/null @@ -1,834 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates calculating the address and predicates to the load of tiles - from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last - "residue" tile first, with the objective of minimizing predicate mask updates - during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileAccessIterator2dThreadTile -/// -template -class PredicatedTileAccessIterator2dThreadTile; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -/// -template -class PredicatedTileAccessIterator2dThreadTile { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kPredicatesPerByte = 4; - static int const kPredicatesPerWord = 4 * kPredicatesPerByte; - - /// Number of 32b words containing predicates - static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte; - static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; - - static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; - - static_assert(kPredicateWordCount <= 4, "Too many predicates."); - - /// Predicate vector stores mask to guard accesses - using Mask = Array; - - /// Uses a non-template class - struct Params : PredicatedTileAccessIteratorParams { - - public: - friend PredicatedTileAccessIterator2dThreadTile; - - using Base = PredicatedTileAccessIteratorParams; - - // Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : - Base(layout.stride(0), - MakePredicatedTileAccessIteratorDesc()() - ) { } - - CUTLASS_HOST_DEVICE - Params(Base const &base) : - Base(base) { } - }; - - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const ¶ms_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - - /// Guard predicates - uint32_t predicates_[kPredicateWordCount]; - - /// Size of tensor - TensorCoord extent_; - - /// Initial offset for each thread - TensorCoord thread_offset_; - - /// Index of residue tile - int residue_tile_idx_; - - /// Used for out-of-order visitation - bool is_residue_tile_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - /// Tracks iterations within the thread loop - int iteration_thread_; - - private: - /// Computes predicates based on internally tracked per-thread offset. - CUTLASS_HOST_DEVICE - void compute_predicates_( - /// optionally, simplify predicate calculation during 'steady state' phase - bool is_steady_state = false) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0u; - } - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) { - - TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous, - ts + s * ThreadMap::Delta::kStrided); - - TensorCoord coord = thread_offset_ + iteration_coord; - - bool guard; - - if (is_steady_state) { - if (kAdvanceRank == 0) { - guard = (coord.strided() < extent_.strided()); - } else { - guard = (coord.contiguous() < extent_.contiguous()); - } - } else { - guard = (coord.strided() < extent_.strided() && - coord.contiguous() < extent_.contiguous()); - } - - int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); - - } - } - } - - } - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : params_(params), - pointer_(reinterpret_cast( - const_cast(pointer))), - extent_(extent), - is_residue_tile_(true) { - - - TensorCoord residue_offset; - if (kAdvanceRank) { - residue_tile_idx_ = - (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / - Shape::kStrided; - residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided); - } else { - residue_tile_idx_ = - (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / - Shape::kContiguous; - residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0); - } - - // Per-thread offset in logical coordinates of tensor - thread_offset_ = threadblock_offset + residue_offset + - ThreadMap::initial_offset(thread_id); - - // update internal pointers - Layout layout(params_.stride_); - add_pointer_offset(layout(thread_offset_)); - - compute_predicates_(false); - - set_iteration_index(0); - } - - /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id) - : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); - iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); - - iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided; - iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided; - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += int(sizeof(Element)) * pointer_offset; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - if (is_residue_tile_) { - TensorCoord residue_offset; - if (kAdvanceRank) { - residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided); - } else { - residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0); - } - - thread_offset_ -= residue_offset; - - Layout layout(params_.stride_); - add_pointer_offset(-layout(residue_offset)); - - compute_predicates_(true); - - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } else { - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * tile_offset.strided(); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * tile_offset.contiguous(); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } - is_residue_tile_ = false; - } - - CUTLASS_HOST_DEVICE - AccessType *get() const { - - AccessType *ret_val = reinterpret_cast( - pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element))); - - return ret_val; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile &operator++() { - - iteration_thread_++; - - if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided) - return *this; - - iteration_thread_ = 0; - - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - pointer_ += params_.inc_strided_; - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - // advance to next tile - pointer_ += params_.inc_next_; - - // now return to start tile - if the iterator is subsequently advanced, this - // subtraction as well as the subsequent integer addition are both elided by - // the compiler. - pointer_ -= params_.inc_advance_; - - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile operator++(int) { - PredicatedTileAccessIterator2dThreadTile self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = enable ? 0u : predicates_[i]; - } - - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0xffffffff; - } - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = mask[i]; - } - - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - mask[i] = predicates_[i]; - } - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - - int pred_idx = - iteration_thread_ + - iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided + - iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; - - return pred; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator2dThreadTile { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator2dThreadTile; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))){} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), - threadblock_offset.column())) {} - - /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile operator++(int) { - PredicatedTileAccessIterator2dThreadTile self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIterator2dThreadTile { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIterator2dThreadTile; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))){} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator2dThreadTile operator++(int) { - PredicatedTileAccessIterator2dThreadTile self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h deleted file mode 100644 index 5e509a344e955438ea4eabe6806ed2ab79343d36..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +++ /dev/null @@ -1,290 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/detail/helper_macros.hpp" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Predicated tile access iterator descriptor object containing template dependent state -struct PredicatedTileAccessIteratorDesc { - - int element_size_bits = -1; - int advance_rank = -1; - layout::PitchLinearCoord threadblock_shape; - layout::PitchLinearCoord threadmap_iterations; - layout::PitchLinearCoord threadmap_delta; - - // - // Methods - // - - PredicatedTileAccessIteratorDesc() = default; - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc( - int element_size_bits_, - int advance_rank_, - layout::PitchLinearCoord threadblock_shape_, - layout::PitchLinearCoord threadmap_iterations_, - layout::PitchLinearCoord threadmap_delta_ - ): - element_size_bits(element_size_bits_), - advance_rank(advance_rank_), - threadblock_shape(threadblock_shape_), - threadmap_iterations(threadmap_iterations_), - threadmap_delta(threadmap_delta_) - { - #if 0 - printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n", - element_size_bits, - advance_rank, - threadblock_shape.contiguous(), threadblock_shape.strided(), - threadmap_iterations.contiguous(), threadmap_iterations.strided(), - threadmap_delta.contiguous(), threadmap_delta.strided()); - #endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Helper template to construct an PredicatedTileAccessIteratorDesc from a template -// dependent state -template < - typename Shape, typename Element, typename Layout, - int AdvanceRank, typename ThreadMap> - struct MakePredicatedTileAccessIteratorDesc; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for pitch-linear data. -template < - typename Shape, typename Element, int AdvanceRank, - typename ThreadMap> -struct MakePredicatedTileAccessIteratorDesc < - Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> { - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc operator()() { - - return PredicatedTileAccessIteratorDesc( - sizeof_bits::value, - AdvanceRank, - {Shape::kContiguous, Shape::kStrided}, - {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, - {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} - ); -} - -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for column-major data. -template < - typename Shape, typename Element, int AdvanceRank, - typename ThreadMap> -struct MakePredicatedTileAccessIteratorDesc < - Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> { - - static int const kAdvanceRank = AdvanceRank; - - using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc operator()() { - - return UnderlyingMakeOperator()(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for row-major data. -template < - typename Shape, typename Element, int AdvanceRank, - typename ThreadMap> -struct MakePredicatedTileAccessIteratorDesc < - Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> { - - static int const kAdvanceRank = AdvanceRank; - - using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc operator()() { - - return UnderlyingMakeOperator()(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. -template < - typename Shape, typename Element, int AdvanceRank, - typename ThreadMap, int InterleavedK> -struct MakePredicatedTileAccessIteratorDesc < - Shape, Element, layout::ColumnMajorInterleaved, AdvanceRank, ThreadMap> { - - static int const kAdvanceRank = AdvanceRank; - static int const kInterleavedK = InterleavedK; - - using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc operator()() { - - return UnderlyingMakeOperator()(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for roww-major interleaved data. -template < - typename Shape, typename Element, int AdvanceRank, - typename ThreadMap, int InterleavedK> -struct MakePredicatedTileAccessIteratorDesc < - Shape, Element, layout::RowMajorInterleaved, AdvanceRank, ThreadMap> { - - static int const kAdvanceRank = AdvanceRank; - static int const kInterleavedK = InterleavedK; - - using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc operator()() { - - return UnderlyingMakeOperator()(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Parameters struct -// - -struct PredicatedTileAccessIteratorParams { - - using Index = int32_t; - using LongIndex = int64_t; - - // - // Data members - // - /// stride of pitch-linear layout (units of Element) - LongIndex stride_ = 0; - /// amount (in byte) to increment pointer to move to next access along - /// strided dimension - LongIndex inc_strided_ = 0; - /// amount (in byte) to increment pointer from last access to first access - /// of next tile - LongIndex inc_next_ = 0; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_ = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { - CUTLASS_ASSERT(desc.element_size_bits > 0); - CUTLASS_ASSERT(desc.advance_rank == 0 || desc.advance_rank == 1); - - stride_ = stride; - - inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) * - desc.element_size_bits / 8; - - if (desc.advance_rank) { - // advance along strided dimension - inc_advance_ = - desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8; - } else { - // advance along contiguous dimension - inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8; - } - - inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) * - desc.threadmap_delta.strided() * LongIndex(stride_) * - desc.element_size_bits / 8; - - return Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) { - return initialize(LongIndex(stride), desc); - } - - PredicatedTileAccessIteratorParams() = default; - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) { - initialize(stride, desc); - } - - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { - initialize(stride, desc); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h deleted file mode 100644 index f657fe25813567b47156047f6ef023b678ac097f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h +++ /dev/null @@ -1,892 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates calculating the address and predicates to the load of tiles - from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last - "residue" tile first, with the objective of minimizing predicate mask updates - during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be - stored in registers, and integer addition is used to advance the pointer - through memory. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileAccessIteratorTriangularMatrix -/// -template -class PredicatedTileAccessIteratorTriangularMatrix; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data. -/// -template -class PredicatedTileAccessIteratorTriangularMatrix { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - using CompareOp = typename TrMatrixCompareOp::Type; - - static_assert( kFillMode == FillMode::kFull || - ((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1), - "BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1"); - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), - "Vectors implied by the thread map must be divisible by the access type."); - - static int const kPredicatesPerByte = 4; - static int const kPredicatesPerWord = 4 * kPredicatesPerByte; - - static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; - - /// Number of 32b words containing predicates - static int const kPredicateByteCount = - (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; - static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; - - static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; - - static_assert(kPredicateWordCount <= 4, "Too many predicates."); - - /// Predicate vector stores mask to guard accesses - using Mask = Array; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - friend PredicatedTileAccessIteratorTriangularMatrix; - - private: - /// stride of pitch-linear layout (units of Element) - StrideIndex stride_; - /// (true) pitch-linear layout is mapped to row-major matrix - /// (false) pitch-linear layout is mapped to column-major matrix - bool is_row_major_; - /// for vectorized access across the diagonal boundary guard condition is - /// checked for the element on the boundary - int access_diagonal_boundary_; - /// amount (in byte) to increment pointer to move to next access along - /// strided dimension - LongIndex inc_strided_; - /// amount (in byte) to increment pointer from last access to first access - /// of next tile - LongIndex inc_next_; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_; - - public: - - // Default ctor - CUTLASS_HOST_DEVICE - Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) : - stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) { - - inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * - sizeof_bits::value / 8; - - if (kAdvanceRank) { - // advance along strided dimension - inc_advance_ = - Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; - } else { - // advance along contiguous dimension - inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; - } - - inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * - ThreadMap::Delta::kStrided * LongIndex(stride_) * - sizeof_bits::value / 8; - - }; - - - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const ¶ms_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - - /// Guard predicates - uint32_t predicates_[kPredicateWordCount]; - - /// Track global memory addresses on the diagonal - /// To ignore imag part for diagonal elements of hermitian matrices - uint32_t predicates_onDiag_[kPredicateWordCount]; - - /// Size of tensor - TensorCoord extent_; - - /// Initial offset for each thread - TensorCoord thread_offset_; - - /// Iteration along vectors implied by the thread map - int iteration_vector_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - private: - /// Computes predicates based on internally tracked per-thread offset. - CUTLASS_DEVICE - void compute_predicates_( - /// Extent of the matrix window - TensorCoord extent) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0u; - predicates_onDiag_[i] = 0u; - } - - CompareOp compare_op; - - CUTLASS_PRAGMA_UNROLL - for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { - - int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); - - int c = access_residual / kAccessesPerVector; - int v = access_residual % kAccessesPerVector; - - TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, - s * ThreadMap::Delta::kStrided); - - TensorCoord coord = thread_offset_ + iteration_coord; - - bool guard; - bool onDiag = false; - - guard = ((coord.strided() < extent.strided()) && - (coord.contiguous() < extent.contiguous())); - - - // guard access on the wrong side of the triagular matrix diagonal - if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) { - coord += TensorCoord{params_.access_diagonal_boundary_, 0}; - - bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_; - bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_; - - guard = guard && triagular_guard_row_major && triagular_guard_col_major; - - if (kDiagType == DiagType::kUnit) { - onDiag = (guard && coord.strided() == coord.contiguous()) ? true : false; - } - } - - int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); - int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord; - int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord; - int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte; - int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte; - - predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag)); - - int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); - - } - - } - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : params_(params), - pointer_(reinterpret_cast(const_cast(pointer))), - extent_(extent) { - - - // Per-thread offset in logical coordinates of tensor - thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); - - // update internal pointers - Layout layout(params_.stride_); - add_pointer_offset(layout(thread_offset_)); - - compute_predicates_(extent_); - - set_iteration_index(0); - } - - /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id) - : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_vector_ = index % kAccessesPerVector; - int residual_access = index / kAccessesPerVector; - - iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; - iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += sizeof_bits::value * pointer_offset / 8; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()}; - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); - pointer_ += Shape::kStrided * tile_offset.strided(); - thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0}; - } - - compute_predicates_(extent_); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast( - pointer_ + - iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix &operator++() { - - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - - iteration_vector_ = 0; - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - pointer_ += params_.inc_strided_; - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - // advance to next tile - pointer_ += params_.inc_next_; - - // now return to start tile - if the iterator is subsequently advanced, this - // subtraction as well as the subsequent integer addition are both elided by - // the compiler. - pointer_ -= params_.inc_advance_; - - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix operator++(int) { - PredicatedTileAccessIteratorTriangularMatrix self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = enable ? 0u : predicates_[i]; - } - - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = 0xffffffff; - } - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - predicates_[i] = mask[i]; - } - - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kPredicateWordCount; ++i) { - mask[i] = predicates_[i]; - } - } - - /// Return if the address in on the diagonal - CUTLASS_HOST_DEVICE - bool getOnDiag() { - int pred_idx = - iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; - return pred; - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - - - int pred_idx = - iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); - - int word_idx = pred_idx / kPredicatesPerWord; - int residual = pred_idx % kPredicatesPerWord; - int byte_idx = residual / kPredicatesPerByte; - int bit_idx = residual % kPredicatesPerByte; - - bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; - return pred; - - - //return true; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIteratorTriangularMatrix { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, - kSideMode, kFillMode, kDiagType, AccessType>; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - static int const kAccessDiagonalBoundary = - (kFillMode == FillMode::kLower) ? (AccessType::kElements - 1) : 0; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIteratorTriangularMatrix; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){}; - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), - threadblock_offset.column())) {} - - /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix operator++(int) { - PredicatedTileAccessIteratorTriangularMatrix self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Return if the address in on the diagonal - CUTLASS_HOST_DEVICE - bool getOnDiag() { - return iterator_.getOnDiag(); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileAccessIteratorTriangularMatrix { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - using AccessType = AccessType_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< - layout::PitchLinearShape, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, - kSideMode, kFillMode, kDiagType, AccessType>; - - static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; - - static int const kAccessDiagonalBoundary = - (kFillMode == FillMode::kUpper) ? (AccessType::kElements - 1) : 0; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileAccessIteratorTriangularMatrix; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){}; - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix( - ///< Precomputed parameters object - Params const ¶ms, - ///< Pointer to start of tensor - Pointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorTriangularMatrix operator++(int) { - PredicatedTileAccessIteratorTriangularMatrix self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Return if the address in on the diagonal - CUTLASS_HOST_DEVICE - bool getOnDiag() { - return iterator_.getOnDiag(); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h deleted file mode 100644 index 43c4cbd1a5758e0288f82babbe7043d22f83c009..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h +++ /dev/null @@ -1,1887 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses. The first tile this - iterator visits maybe partial, then the remaining tiles are complete. So, we - only need to compute the predicates twice, once before the first tile and - once for the remaining full tiles which can share the same predicates. - - A precomputed "Params" object minimizes the amount of state that must be stored in registers, - and integer addition is used to advance the pointer through memory. -*/ - -#pragma once - -#include "cutlass/arch/memory.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileIterator -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -/// Regular tile iterator using a precomputed control structure to minimize register liveness -/// and integer arithmetic. -/// -/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -/// -/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -/// Subsequently, they are assumed to be immutable. -/// -/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -/// -/// Visitation order is intended to first visit a "residual" tile that may be partially full in -/// both the advance dimension and the steady-state dimension. This is assumed to be the last -/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -/// accesses may be performed without updating internal predicates and are efficient in terms of -/// live register state and pointer arithmetic instructions. -/// -/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -/// outside any looping structure to minimize integer arithmetic. -/// -/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -/// the iterator. -/// -/// -/// Example: -/// -/// An efficient pipeline structure may be constructed as follows: -/// -// template -// __global__ void kernel( -// typename Iterator::Params params, -// typename Iterator::Element *ptr, -// TensorCoord extent) { -// -// typename Iterator::Fragment fragment; -// -// TensorCoord threadblock_offset(0, 0); -// -// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -// -// -// fragment = *iter; // load "residue" tile first -// ++iter; // advance to first "steady state" tile and update internal masks -// -// -// #pragma unroll -// for (int i = Remaining - 1; i >= 0; --i) { -// -// f(fragment); -// -// if (!i) { -// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -// } -// -// fragment = *iter; // load tile during "steady state" phase -// ++iter; // advance to next tile - lightweight due to steady-state masks -// } -// } -// -// void host(TensorView view) { -// -// using Iterator = transform::threadblock::PredicatedTileIterator; -// -// typename Iterator::Params params(view.layout()); -// -// kernel(params, view.data()); -// } -/// -/// -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - int AccessSize = ThreadMap::kElementsPerAccess, - bool Gather = false, - typename PermuteLayout = layout::NoPermute -> -class PredicatedTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - /// Type used for internal memory accesses - using AccessType = AlignedArray::value / 8)>; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = - PredicatedTileAccessIterator; - - static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename TileAccessIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - using Base = typename TileAccessIterator::Params::Base; - - friend PredicatedTileIterator; - - private: - /// Parameters object - typename TileAccessIterator::Params params_; - - public: - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : params_(layout) {} - - /// Default constructor - Params() = default; - - CUTLASS_HOST_DEVICE - Params(Base const &base) - : params_(base) {} - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - /// Gather indices - int const *indices = nullptr) - : address_iterator_(params.params_, pointer, extent, thread_id, - threadblock_offset, indices) {} - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - if (kAdvanceRank) - address_iterator_.add_tile_offset({0, 1}); - else - address_iterator_.add_tile_offset({1, 0}); - - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { address_iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } - - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - address_iterator_.set_iteration_index(idx); - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - cutlass::arch::global_load( - frag_ptr[idx], access_ptr, address_iterator_.valid()); - - ++address_iterator_; - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_byte_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType *access_ptr = reinterpret_cast(byte_ptr); - - if (address_iterator_.valid()) { - *access_ptr = frag_ptr[idx]; - } - ++address_iterator_; - } - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for column-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int AccessSize, - bool Gather, - typename PermuteLayout -> -class PredicatedTileIterator { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap, - AccessSize, - Gather, - PermuteLayout - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) - {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset, ///< Initial offset of threadblock - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), - indices) - { } - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int AccessSize, - bool Gather, - typename PermuteLayout -> -class PredicatedTileIterator { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - AccessSize, - Gather, - PermuteLayout - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - - }; - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset, ///< Initial offset of threadblock - int const *indices = nullptr ///< Gather indices - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), - indices - ) { } - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for affine rank-2 data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileIterator, AdvanceRank, - ThreadMap_, AccessSize, false> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::AffineRankN<2>; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - /// Type used for internal memory accesses - using AccessType = AlignedArray::value / 8)>; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = - PredicatedTileAccessIterator; - - static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename TileAccessIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - - friend PredicatedTileIterator; - - private: - /// Parameters object - typename TileAccessIterator::Params params_; - - public: - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : params_(layout) {} - - /// Default constructor - Params() = default; - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : address_iterator_(params.params_, pointer, extent, thread_id, - threadblock_offset) {} - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - if (kAdvanceRank) - address_iterator_.add_tile_offset(make_Coord(0, 1)); - else - address_iterator_.add_tile_offset(make_Coord(1, 0)); - - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { address_iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } - - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - address_iterator_.set_iteration_index(idx); - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - cutlass::arch::global_load( - frag_ptr[idx], access_ptr, address_iterator_.valid()); - - ++address_iterator_; - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_byte_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType *access_ptr = reinterpret_cast(byte_ptr); - - if (address_iterator_.valid()) { - *access_ptr = frag_ptr[idx]; - } - ++address_iterator_; - } - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for affine rank 2 column-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int AccessSize -> -class PredicatedTileIterator { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::AffineRank2ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - // Map to the underlying AffineRankN<2> layout - using UnderlyingIterator = PredicatedTileIterator< - layout::PitchLinearShape, - Element, - layout::AffineRankN<2>, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap, - AccessSize - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given an AffineRankN<2> tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) - {} - }; - -private: - - // - // Data members - // - - /// Underlying AffineRankN<2> tile iterator - UnderlyingIterator iterator_; - -public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset, ///< Initial offset of threadblock - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) - ) { } - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for affine rank 2 row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int AccessSize -> -class PredicatedTileIterator { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::AffineRank2RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - // Map to the underlying AffineRankN<2> layout - using UnderlyingIterator = PredicatedTileIterator< - layout::PitchLinearShape, - Element, - layout::AffineRankN<2>, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - AccessSize - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given an AffineRankN<2> tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} - }; - - -private: - - // - // Data members - // - - /// Underlying AffineRankN<2> tile iterator - UnderlyingIterator iterator_; - -public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset, ///< Initial offset of threadblock - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) - ) { } - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for interleaved data. It is mapped -/// to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// - -template -class PredicatedTileIterator, - AdvanceRank, ThreadMap_, AccessSize, false> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::ColumnMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; - - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.row() * kInterleavedK, - extent.column() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.row() * kInterleavedK, - threadblock_offset.column() / kInterleavedK)) {} - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator for interleaved-32 data. It is -/// mapped to the congruous layout. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileIterator, - AdvanceRank, ThreadMap_, AccessSize, false> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - static int const kInterleavedK = InterleavedK; - using Layout = layout::RowMajorInterleaved; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIterator< - layout::PitchLinearShape, - Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; - - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - friend PredicatedTileIterator; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - /// Default constructor - Params() = default; - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) - : params_(layout::PitchLinear(layout.stride(0))) {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - - /// Default constructor - PredicatedTileIterator() = default; - - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : iterator_(params.params_, pointer, - layout::PitchLinearCoord(extent.column() * kInterleavedK, - extent.row() / kInterleavedK), - thread_id, - layout::PitchLinearCoord( - threadblock_offset.column() * kInterleavedK, - threadblock_offset.row() / kInterleavedK)) {} - - /// Construct a PredicatedTileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileIterator(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator operator++(int) { - PredicatedTileIterator self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { iterator_.get_mask(mask); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h deleted file mode 100644 index cbe48df6e7dc1c66c9e55b8eab14aa1fb53bc14b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h +++ /dev/null @@ -1,787 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile - first, with the objective of minimizing predicate mask updates during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be stored in registers, - and integer addition is used to advance the pointer through memory. -*/ - -#pragma once - -#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h" -#include "cutlass/transform/thread/transpose.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileIterator2dThreadTile -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -/// Regular tile iterator using a precomputed control structure to minimize register liveness -/// and integer arithmetic. -/// -/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -/// -/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -/// Subsequently, they are assumed to be immutable. -/// -/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -/// -/// Vistitation order is intended to first visit a "residual" tile that may be partially full in -/// both the advance dimension and the steady-state dimension. This is assumed to be the last -/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -/// accesses may be performed without updating internal predicates and are efficient in terms of -/// live register state and pointer arithmetic instructions. -/// -/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -/// outside any looping structure to minimize integer arithmetic. -/// -/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -/// the iterator. -/// -/// -/// Example: -/// -/// An efficient pipeline structure may be constructed as follows: -/// -// template -// __global__ void kernel( -// typename Iterator::Params params, -// typename Iterator::Element *ptr, -// TensorCoord extent) { -// -// typename Iterator::Fragment fragment; -// -// TensorCoord threadblock_offset(0, 0); -// -// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -// -// -// fragment = *iter; // load "residue" tile first -// ++iter; // advance to first "steady state" tile and update internal masks -// -// -// #pragma unroll -// for (int i = Remaining - 1; i >= 0; --i) { -// -// f(fragment); -// -// if (!i) { -// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -// } -// -// fragment = *iter; // load tile during "steady state" phase -// ++iter; // advance to next tile - lightweight due to steady-state masks -// } -// } -// -// void host(TensorView view) { -// -// using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile; -// -// typename Iterator::Params params(view.layout()); -// -// kernel(params, view.data()); -// } -/// -/// -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - bool Transpose = false -> -class PredicatedTileIterator2dThreadTile; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileIterator2dThreadTile { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - /// Type used for internal memory accesses - /// extra set of parenthesis is needed for VS compiler - struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits::value / - 8)) AccessType { - - Array storage; - - static int const kElements = ThreadMap::kElementsPerAccess; - }; - - /// Optionally this fragment can be 4x4 transposed - using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>; - static bool const transpose = Transpose_; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = - PredicatedTileAccessIterator2dThreadTile; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename TileAccessIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - using Base = typename TileAccessIterator::Params::Base; - - friend PredicatedTileIterator2dThreadTile; - - private: - /// Parameters object - typename TileAccessIterator::Params params_; - - public: - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : params_(layout) { } - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Base const &base) - : params_(base) {} - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset, - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ) - : address_iterator_(params.params_, pointer, extent, thread_id, - threadblock_offset) {} - - /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile &operator++() { - if (kAdvanceRank) - address_iterator_.add_tile_offset({0, 1}); - else - address_iterator_.add_tile_offset({1, 0}); - - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile operator++(int) { - PredicatedTileIterator2dThreadTile self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { address_iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ - - int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ - s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; - - address_iterator_.set_iteration_index(access_idx); - if (address_iterator_.valid()) { - - frag_ptr[access_idx] = - *(address_iterator_.get() + pointer_offset); - } - - ++address_iterator_; - } - } - } - - if (transpose) { - Transform t; - t.transform(frag, frag); - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ - - int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ - s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; - - address_iterator_.set_iteration_index(access_idx); - if (address_iterator_.valid()) { - *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; - } - ++address_iterator_; - } - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - bool Transpose_ -> -class PredicatedTileIterator2dThreadTile { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static bool const Transpose = Transpose_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIterator2dThreadTile< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap, - Transpose - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIterator2dThreadTile; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset, ///< Initial offset of threadblock - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) - ) { } - - /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile operator++(int) { - PredicatedTileIterator2dThreadTile self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - bool Transpose_ -> -class PredicatedTileIterator2dThreadTile { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static bool const Transpose = Transpose_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIterator2dThreadTile< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - Transpose - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIterator2dThreadTile; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { } - - CUTLASS_HOST_DEVICE - Params(typename UnderlyingIterator::Params::Base const &base) - : params_(base) {} - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset, ///< Initial offset of threadblock - int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) - ) { } - - /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIterator2dThreadTile operator++(int) { - PredicatedTileIterator2dThreadTile self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h deleted file mode 100644 index 9bf5e8586675c11bb52e2db5346ff19f489461af..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h +++ /dev/null @@ -1,818 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile - first, with the objective of minimizing predicate mask updates during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be stored in registers, - and integer addition is used to advance the pointer through memory. -*/ - -#pragma once - -#include "cutlass/arch/memory.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedTileIteratorTriangularMatrix -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -/// Regular tile iterator using a precomputed control structure to minimize register liveness -/// and integer arithmetic. -/// -/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -/// -/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -/// Subsequently, they are assumed to be immutable. -/// -/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -/// -/// Vistitation order is intended to first visit a "residual" tile that may be partially full in -/// both the advance dimension and the steady-state dimension. This is assumed to be the last -/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -/// accesses may be performed without updating internal predicates and are efficient in terms of -/// live register state and pointer arithmetic instructions. -/// -/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -/// outside any looping structure to minimize integer arithmetic. -/// -/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -/// the iterator. -/// -/// -/// Example: -/// -/// An efficient pipeline structure may be constructed as follows: -/// -// template -// __global__ void kernel( -// typename Iterator::Params params, -// typename Iterator::Element *ptr, -// TensorCoord extent) { -// -// typename Iterator::Fragment fragment; -// -// TensorCoord threadblock_offset(0, 0); -// -// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -// -// -// fragment = *iter; // load "residue" tile first -// ++iter; // advance to first "steady state" tile and update internal masks -// -// -// #pragma unroll -// for (int i = Remaining - 1; i >= 0; --i) { -// -// f(fragment); -// -// if (!i) { -// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -// } -// -// fragment = *iter; // load tile during "steady state" phase -// ++iter; // advance to next tile - lightweight due to steady-state masks -// } -// } -// -// void host(TensorView view) { -// -// using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix; -// -// typename Iterator::Params params(view.layout()); -// -// kernel(params, view.data()); -// } -/// -/// -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - SideMode kSideMode, - FillMode kFillMode, - DiagType kDiagType, - int AccessSize = ThreadMap::kElementsPerAccess -> -class PredicatedTileIteratorTriangularMatrix; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template -class PredicatedTileIteratorTriangularMatrix { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - /// Type used for internal memory accesses - using AccessType = AlignedArray::value / 8)>; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = - PredicatedTileAccessIteratorTriangularMatrix; - - static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename TileAccessIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - friend PredicatedTileIteratorTriangularMatrix; - - private: - /// Parameters object - typename TileAccessIterator::Params params_; - - public: - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout) : params_(layout) { } - - CUTLASS_HOST_DEVICE - Params() { } - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix( - /// Precomputed parameters object - Params const ¶ms, - /// Pointer to start of tensor - Pointer pointer, - /// Extent of tensor - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : address_iterator_(params.params_, pointer, extent, thread_id, - threadblock_offset) {} - - /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ) - : PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, - make_Coord(0, 0)) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix &operator++() { - if (kAdvanceRank) - address_iterator_.add_tile_offset({0, 1}); - else - address_iterator_.add_tile_offset({1, 0}); - - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix operator++(int) { - PredicatedTileIteratorTriangularMatrix self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { address_iterator_.enable_mask(); } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } - - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - address_iterator_.set_iteration_index(idx); - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - cutlass::arch::global_load( - frag_ptr[idx], access_ptr, address_iterator_.valid()); - - ++address_iterator_; - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_byte_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kAccessesPerVector; ++v) { - - int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); - - char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType *access_ptr = reinterpret_cast(byte_ptr); - - if (address_iterator_.valid()) { - *access_ptr = frag_ptr[idx]; - } - ++address_iterator_; - } - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIteratorTriangularMatrix for column-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - SideMode kSideMode, - FillMode kFillMode, - DiagType kDiagType, - int AccessSize -> -class PredicatedTileIteratorTriangularMatrix { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap, - kSideMode, - kFillMode, - kDiagType, - AccessSize - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIteratorTriangularMatrix; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { - - } - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset ///< Initial offset of threadblock - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.row(), extent.column()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) - ) { } - - /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix operator++(int) { - PredicatedTileIteratorTriangularMatrix self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept | -/// MaskedTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - SideMode kSideMode, - FillMode kFillMode, - DiagType kDiagType, - int AccessSize -> -class PredicatedTileIteratorTriangularMatrix { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - kSideMode, - kFillMode, - kDiagType, - AccessSize - >; - - using AccessType = typename UnderlyingIterator::AccessType; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array; - - /// Predicate vector stores mask to guard accesses - using Mask = typename UnderlyingIterator::Mask; - - /// Parameters object is precomputed state and is host-constructible - class Params { - private: - - friend PredicatedTileIteratorTriangularMatrix; - - /// Parameters object - typename UnderlyingIterator::Params params_; - - public: - - CUTLASS_HOST_DEVICE - Params() { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { - - }; - }; - - -private: - - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - -public: - - /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - TensorCoord const &threadblock_offset ///< Initial offset of threadblock - ): - iterator_( - params.params_, - pointer, - layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, - layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) - ) { } - - /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix( - Params const ¶ms, ///< Precomputed parameters object - Pointer pointer, ///< Pointer to start of tensor - TensorCoord extent, ///< Extent of tensor - int thread_id ///< ID of each participating thread - ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the iterator's - /// internal pointer is reverted to the first "steady state" tile. Subsequent calls - /// are lightweight and must only update the internal pointer. - CUTLASS_HOST_DEVICE - PredicatedTileIteratorTriangularMatrix operator++(int) { - PredicatedTileIteratorTriangularMatrix self(*this); - operator++(); - return self; - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void clear_mask(bool enable = true) { - iterator_.clear_mask(enable); - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE - void enable_mask() { - iterator_.enable_mask(); - } - - /// Sets the predicate mask, overriding value stored in predicate iterator - CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { - iterator_.set_mask(mask); - } - - /// Gets the mask - CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { - iterator_.get_mask(mask); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { - iterator_.load_with_byte_offset(frag, byte_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { - iterator_.store_with_byte_offset(frag, byte_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h deleted file mode 100644 index df551c13f52834bfa6258104f99c7ed008342279..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h +++ /dev/null @@ -1,417 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates implementing computing the addresses of loading small - vectors from the global memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// PredicatedVectorAccessIterator -/// -template < - /// Shape of the vector accessed by the entire threadblock - typename Shape, - /// Shape of the vector accessed by the warp - typename WarpShape, - /// Type of Element - typename Element, - /// Layout of the vector - typename Layout, - /// Number of elements for each access - int ElementsPerAccess, - /// Support residual tile - bool EnableResidualAccess = false -> -class PredicatedVectorAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Vector access iterator specialized for vectors, e.g. scale and bias -/// Thread arrangements are for TensorOps -/// -template < - typename Shape_, - typename WarpShape_, - typename Element_, - int ElementsPerAccess, - bool EnableResidualAccess -> -class PredicatedVectorAccessIterator < - Shape_, - WarpShape_, - Element_, - layout::PitchLinear, - ElementsPerAccess, - EnableResidualAccess -> { - public: - - using Shape = Shape_; - using WarpShape = WarpShape_; - using Element = Element_; - using Layout = layout::PitchLinear; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - -// static int const kElementsPerAccess = 128 / sizeof_bits::value; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kThreads = 32; - static int const kRowsPerIteration = 8; - static int const kThreadsPerRow = kThreads / kRowsPerIteration; - static int const kThreadsPerRowMask = 0x3; - static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess); - static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided; - - using AccessType = AlignedArray; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Internal pointer to first access of tile - BytePointer pointer_; - - /// Extent of tensor - TensorCoord extent_; - - /// pointer offset of each thread - TensorCoord thread_offset_; - - /// iteration index - LongIndex iteration_; - - /// residual access - bool is_residual_; - - /// residual offset of each thread - TensorCoord residual_offset_; - - public: - /// Constructs a vector access iterator - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator( - /// Pointer to the start of the vector - ConstPointer pointer, - /// Extent of vector - TensorCoord extent, - /// ID of each participating thread - int thread_id, - /// ID of each participating warp - int warp_id, - /// Initial offset of threadblock - TensorCoord const &threadblock_offset) - : pointer_(reinterpret_cast( - const_cast(pointer))), - extent_(extent), - is_residual_(false) { - - - int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous; - - // Per-thread offset in logical coordinates of tensor - - thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) + - TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0); - - set_iteration_index(0); - - if(EnableResidualAccess) { - // compute residual offset - typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous; - if (residual_size) { - is_residual_ = true; - residual_offset_ = make_Coord(residual_size, 0); - } - } - } - - /// Construct a PredicatedVectorAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator( - /// Pointer to start of vector - ConstPointer pointer, - /// Extent of vector - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - /// ID of each participating warp - int warp_id) - : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, - make_Coord(0, 0)) {} - - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - iteration_ = index; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - - thread_offset_ = - thread_offset_ + - TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - return reinterpret_cast( - pointer_ + - ((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess) - * sizeof_bits::value / 8)); - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator &operator++() { - ++iteration_; - if(iteration_ >= kIterations) - iteration_ = 0; - - return *this; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - void advance() { - if(EnableResidualAccess && is_residual_) { - is_residual_ = false; - thread_offset_ += residual_offset_; - } - else - add_tile_offset(TensorCoord(1, 0)); - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator operator++(int) { - PredicatedVectorAccessIterator self(*this); - operator++(); - return self; - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return ((thread_offset_.contiguous() + - iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous()); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedVectorAccessIterator for row-major data. -/// -template < - typename Shape_, - typename WarpShape_, - typename Element_, - int ElementsPerAccess, - bool EnableResidualAccess -> -class PredicatedVectorAccessIterator< - Shape_, - WarpShape_, - Element_, - layout::RowMajor, - ElementsPerAccess, - EnableResidualAccess -> { - public: - - using Shape = Shape_; - using WarpShape = WarpShape_; - using Element = Element_; - using Layout = layout::RowMajor; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - - using ConstPointer = const Element *; - using NonConstPointer = typename platform::remove_const::type *; - - using UnderlyingIterator = PredicatedVectorAccessIterator< - layout::PitchLinearShape, - layout::PitchLinearShape, - Element, - layout::PitchLinear, - ElementsPerAccess, - EnableResidualAccess>; - - using AccessType = typename UnderlyingIterator::AccessType; - static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; - static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration; - static int const kThreads = UnderlyingIterator::kThreads; - static int const kIterations = UnderlyingIterator::kIterations; - - private: - // - // Data members - // - - /// Underlying pitch-linear tile iterator - UnderlyingIterator iterator_; - - public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator( - ///< Pointer to the start of the vector - ConstPointer pointer, - ///< Extent of tensor - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< ID of each participating warp - int warp_id, - ///< Initial offset of threadblock - TensorCoord const &threadblock_offset) - : iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()), - thread_id, warp_id, - layout::PitchLinearCoord(threadblock_offset.column(), - threadblock_offset.row())) {} - - /// Construct a PredicatedVectorAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator( - ConstPointer pointer, ///< Pointer to the start of the vector - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - int warp_id ///< ID of each participating warp - ) - : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, - make_Coord(0, 0)) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Advances an iterator along logical dimensions of matrix in units of whole - /// tiles - CUTLASS_HOST_DEVICE - void add_tile_offset(TensorCoord const &tile_offset) { - iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - /// - /// The first time this method is called, predicates are updated, and the - /// iterator's internal pointer is reverted to the first "steady state" tile. - /// Subsequent calls are lightweight and must only update the internal - /// pointer. - CUTLASS_HOST_DEVICE - PredicatedVectorAccessIterator operator++(int) { - PredicatedVectorAccessIterator self(*this); - operator++(); - return self; - } - - /// Increment and return an instance to self. - CUTLASS_HOST_DEVICE - void advance() { - iterator_.advance(); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() { - return iterator_.valid(); - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h deleted file mode 100644 index 1aae46988418c72a9322b7e6b47e1dfe4fadff8d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h +++ /dev/null @@ -1,253 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Templates implementing computing the addresses of storing of small - scale and bias vectors in the shared memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// RegularScaleBiasVectorAccessIterator -/// -template -class RegularScaleBiasVectorAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularScaleBiasVectorAccessIterator { - public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - /// Element type per access - static int const kElementsPerAccess = 128 / sizeof_bits::value; - static int const kThreads = Shape::kContiguous / kElementsPerAccess; - using AccessType = Array; - - private: - // - // Data members - // - - /// Internal pointer - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularScaleBiasVectorAccessIterator( - TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias - ///< vector - int thread_id ///< ID of each participating thread - ) - : byte_offset_(0) { - // Per-thread offset in logical coordinates of tensor - int thread_offset = thread_id * kElementsPerAccess; - - // initialize pointer - pointer_ = - reinterpret_cast(scale_bias_ref.data() + thread_offset); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_DEVICE - AccessType *get() const { - - char *access_byte_ptr = - reinterpret_cast(pointer_); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularScaleBiasVectorAccessIterator &operator++() { return *this; } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularScaleBiasVectorAccessIterator operator++(int) { - RegularScaleBiasVectorAccessIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset in the unit of tile. - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - // Multiply by 2 because we store scale and bias belong to the same stage - // next to each other. - add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for row major layouts -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularScaleBiasVectorAccessIterator< - Shape_, Element_, - layout::RowMajor> { - public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - /// Underlying iterator type - using UnderlyingIterator = RegularScaleBiasVectorAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularScaleBiasVectorAccessIterator( - TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias - ///< vector - int thread_id ///< ID of each participating thread - ) - : iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) { - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularScaleBiasVectorAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularScaleBiasVectorAccessIterator operator++(int) { - RegularScaleBiasVectorAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h deleted file mode 100644 index cfb491b5a4b5f4e1b757f99110f6a9fd28675088..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h +++ /dev/null @@ -1,58 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing the address computation of storing of tiles - from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template ::value* ThreadMap::kElementsPerAccess / 8> -class RegularTileAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h deleted file mode 100644 index adda9339b87865799c56baba4c3f8df580e26ac5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h +++ /dev/null @@ -1,408 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing computing the addresses of storing of tiles - from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::PitchLinear, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // initialize pointer - pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_DEVICE - AccessType *get() const { - - AccessType *access_ptr = pointer_; - - int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset in the unit of tile. - /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. - /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. - /// For row major A operand, k dimension is contiguous dimension; - /// For col major A operand, k dimension is strided dimension; - /// For row major B operand, k dimension is strided dimension; - /// For col major B operand, k dimension is contiguous dimension. - /// Below two classes map col/row major to the pitch linear coordinates used - /// in this base class. - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset(coord.contiguous() * Shape::kContiguous + - coord.strided() * Shape::kStrided * stride_ * - ThreadMap::kElementsPerAccess); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for column major layouts -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajor, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for row major layouts -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::RowMajor, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h deleted file mode 100644 index 71c89686a71995b45f9d4cf0fd1f0fba12ca7d8a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h +++ /dev/null @@ -1,587 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing computing the addresses of storing of tiles - from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - - -//////////////////////////////////////////////////////////////////////////////// - -template ::value* ThreadMap::kElementsPerAccess / 8 - > -class RegularTileAccessIteratorDirectConv; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIteratorDirectConv< - Shape_, Element_, - layout::PitchLinear, - AdvanceRank, ThreadMap_, false, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // initialize pointer - pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_num(int num) { - //Do nothing - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_DEVICE - AccessType *get() const { - - AccessType *access_ptr = pointer_; - - int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv operator++(int) { - RegularTileAccessIteratorDirectConv prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset in the unit of tile. - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset(coord.contiguous() * Shape::kContiguous + - coord.strided() * ThreadMap::Iterations::kStrided * - ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIteratorDirectConv< - Shape_, Element_, - layout::PitchLinear, - AdvanceRank, ThreadMap_,true, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - /// Total iterattions in the strided dimension: Dynamic value - int total_iteration_strided_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // initialize pointer - pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_num(int num) { - total_iteration_strided_ = num; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_DEVICE - AccessType *get() const { - - AccessType *access_ptr = pointer_; - - int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < total_iteration_strided_) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv operator++(int) { - RegularTileAccessIteratorDirectConv prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset in the unit of tile. - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset(coord.contiguous() * Shape::kContiguous + - coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ * - ThreadMap::kElementsPerAccess); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for column major layouts -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIteratorDirectConv< - Shape_, Element_, - layout::ColumnMajor, - AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIteratorDirectConv< - layout::PitchLinearShape, Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap_, - Dynamic_iterations>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_num(int num) { - iterator_.set_iteration_num(num); - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv operator++(int) { - RegularTileAccessIteratorDirectConv prev(*this); - ++iterator_; - - return prev; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for row major layouts -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIteratorDirectConv< - Shape_, Element_, - layout::RowMajor, - AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIteratorDirectConv< - layout::PitchLinearShape, Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap_, - Dynamic_iterations>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_num(int num) { - iterator_.set_iteration_num(num); - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIteratorDirectConv operator++(int) { - RegularTileAccessIteratorDirectConv prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h deleted file mode 100644 index e172447fa96b02e11246f5f397911841c52eff4c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h +++ /dev/null @@ -1,821 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing computing the addresses of storing of tiles - from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::TensorOpMultiplicandCongruous::value, - Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::TensorOpMultiplicandCongruous::value, - Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - static int const kCrosswise = Crosswise; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 128; - - static_assert(sizeof_bits::value * - ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - - ///< Number of pointers - static int const kPointerCount = - (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); - }; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_[Detail::kPointerCount]; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), - byte_offset_(0) { - layout::PitchLinearCoord thread_offset_base = - ThreadMap::initial_offset(thread_id); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Detail::kPointerCount; ++i) { - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = - thread_offset_base + - layout::PitchLinearCoord{ - 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; - - // initialize pointer - pointer_[i] = reinterpret_cast( - ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - AccessType *access_ptr = pointer_[iteration_strided_ & 1]; - int stride_idx = (iteration_strided_ & ~1); - - int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset(coord.contiguous() * Shape::kContiguous * Layout::kFactor + - coord.strided() * Shape::kStrided * stride_ * - Layout::kElementsPerAccess / Layout::kFactor); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous::value, - Crosswise>, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::RowMajorTensorOpMultiplicandCongruous::value, - Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous::value, - Crosswise>, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for crosswise arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::TensorOpMultiplicandCrosswise::value, - Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - static int const kCrosswise = Crosswise; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise), - "kCrosswise is the smallest unit in the contiguous dimension " - "for shared memory swizzling."); - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 128; - - static_assert(sizeof_bits::value * - ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - - /// Number of pointers - /// - /// Note:TN kblock32 layouts only needs 1 pointer, but strangely - /// reducing pointer count hurts perfomrnace - static int const kPointerCount = - (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); - }; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Total number of sections. The memory is divided into stages. One stage - /// can store one tile. Stage is divided into sections. Interleaved layout - /// can have multiple sections in a stage. The rest layout only has one section - /// in a stage. - int sections_; - - /// Sections that a stage has - int sections_per_stage_; - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_[Detail::kPointerCount]; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : sections_(ref.stride(0) / kCrosswise), - sections_per_stage_(Shape::kContiguous / kCrosswise), - // stride_ = kCrosswise x sections_ x kFactor - stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), - byte_offset_(0) { - layout::PitchLinearCoord thread_offset_base = - ThreadMap::initial_offset(thread_id); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Detail::kPointerCount; ++i) { - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = - thread_offset_base + - layout::PitchLinearCoord{ - 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; - // initialize pointer - pointer_[i] = reinterpret_cast(ref.data()) + - ref.offset(thread_offset_in_threadblock_tile) / - Layout::kElementsPerAccess; - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - AccessType *access_ptr = pointer_[iteration_strided_ & 1]; - int stride_idx = (iteration_strided_ & ~1); - - int access_offset = - stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + - // kCrosswise elements in the contiguous dimension would span to a - // shared memory cache line. - iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) * - Layout::TileShape::kContiguous; - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) - // which means we enter the next section. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ * - ThreadMap::kElementsPerAccess / sections_ + - coord.strided() * Shape::kStrided * stride_ * - Layout::kElementsPerAccess / Layout::kFactor); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCrosswise::value, - Crosswise>, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCrosswise::value, - Crosswise>, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h deleted file mode 100644 index b55f841eee2e09aec8af5c8ec945a1997705c9f6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h +++ /dev/null @@ -1,1532 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing computing the addresses of storing of tiles - from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::TensorOpMultiplicandCongruous64b, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorOpMultiplicandCongruous64b; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - static_assert(ThreadMap::kThreads / 32 > 1, - "This tile iterator requires at least two warps."); - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 64; - - static_assert(sizeof_bits::value * - ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 64b"); - - ///< Number of pointers - static int const kPointerCount = 1; - }; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): - stride_(ref.stride(0) / Layout::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; - - // initialize pointer - pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - AccessType *access_ptr = pointer_; - - int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - - RegularTileAccessIterator prev(*this); - - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - - add_pointer_offset( - coord.contiguous() * Shape::kContiguous + - coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicandCongruous64b, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous64b, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous64b, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for crosswise arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::TensorOpMultiplicand64bCrosswise, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorOpMultiplicand64bCrosswise; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - static_assert(ThreadMap::kThreads / 32 > 1, - "This tile iterator requires at least two warps."); - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 64; - - static_assert(sizeof_bits::value * - ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 64b"); - - ///< Number of pointers - two pointers are needed if making more than 4 iterations along - ///< strided dimension - static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1); - }; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_[Detail::kPointerCount]; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_DEVICE - RegularTileAccessIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): - stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; - - // initialize pointer - pointer_ = reinterpret_cast(ref.data()); - - byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element); - - if (Detail::kPointerCount == 2) { - byte_offset_[1] = byte_offset_[0] ^ 8; - } - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - pointer_ += pointer_offset / ThreadMap::kElementsPerAccess; - } - - /// Returns a pointer - CUTLASS_DEVICE - AccessType *get() const { - - // Map the logical contiguous and strided access to the internal swizzled structure. - int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_; - - char *access_byte_ptr = reinterpret_cast(pointer_ + uniform_offset); - - int byte_offset; - - // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations) - // in the strided dimension - if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) { - byte_offset = byte_offset_[1]; - } - else { - byte_offset = byte_offset_[0]; - } - - return reinterpret_cast(access_byte_ptr + byte_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - - RegularTileAccessIterator prev(*this); - - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - - add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicand64bCrosswise, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicand64bCrosswise, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicand64bCrosswise, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::TensorOpMultiplicandCongruous128b, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorOpMultiplicandCongruous128b; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - static_assert(ThreadMap::kThreads / 32 > 1, - "This tile iterator requires at least two warps."); - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 128; - - static_assert(sizeof_bits::value * - ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 128b"); - - ///< Number of pointers - static int const kPointerCount = 1; - }; - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): - stride_(ref.stride(0) / Layout::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; - - // initialize pointer - pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - AccessType *access_ptr = pointer_; - - int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - - RegularTileAccessIterator prev(*this); - - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - - add_pointer_offset( - coord.contiguous() * Shape::kContiguous + - coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicandCongruous128b, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous128b, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous128b, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): - iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::TensorOpMultiplicandCrosswise128x4, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::TensorOpMultiplicandCrosswise128x4; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - static_assert(ThreadMap::kThreads / 32 > 1, - "This tile iterator requires at least two warps."); - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 128; - - static_assert(sizeof_bits::value * - ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 128b"); - - ///< Number of pointers - static int const kPointerCount = 1; - }; - - - static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension"); - - /// Element type per access - using AccessType = Array; - - private: - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType *pointer_; - - /// Internal byte offset - Index byte_offset_; - - /// Iteration in the contiguous dimension - int iteration_contiguous_; - - /// Iteration in the strided dimension - int iteration_strided_; - - public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_DEVICE - RegularTileAccessIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): - stride_(ref.stride(0) / Layout::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; - - // initialize pointer - pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - - set_iteration_index(0); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - - AccessType *access_ptr = pointer_; - - int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2); - int offset_s = (iteration_strided_ / 2) * 8; - - int access_offset = offset_c * stride_ + offset_s; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - - return reinterpret_cast(access_byte_ptr + byte_offset_); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) - return *this; - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - - RegularTileAccessIterator prev(*this); - - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - - add_pointer_offset( - coord.contiguous() * Shape::kContiguous * stride_ + - coord.strided() * Shape::kStrided * Layout::kElementsPerAccess); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCrosswise128x4, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileAccessIterator { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileAccessIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCrosswise128x4, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - using AccessType = typename UnderlyingIterator::AccessType; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileAccessIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): - iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { iterator_.set_iteration_index(index); } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast(iterator_.get()); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileAccessIterator operator++(int) { - RegularTileAccessIterator prev(*this); - ++iterator_; - - return prev; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h deleted file mode 100644 index be07e43f6f45132f79d95afb95714c4392149b66..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h +++ /dev/null @@ -1,62 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 -> -class RegularTileIterator; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h deleted file mode 100644 index 6c186ce3fe0650c3f8927d84f1983916d9d1867f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ /dev/null @@ -1,552 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile - first, with the objective of minimizing predicate mask updates during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be stored in registers, - and integer addition is used to advance the pointer through memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Regular tile iterator specialized for pitch-linear. This one is used by 2-stage SIMT kernels -/// and sparse tensor core meta data. -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator { -public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using Fragment = Array; - - using AccessType = AlignedArray; - - static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, - "Advance rank may only be along the contiguous or strided dimensions."); - -private: - - // - // Types - // - - // - // Data members - // - - /// Pointer to memory - uint8_t *pointer_; - - /// Stride quantity - StrideIndex stride_; - - /// Amount to increment pointer along strided dimension - Index increment_strided_; - - /// Amount to advance pointer between tiles - Index increment_advance_; - -public: - - CUTLASS_DEVICE - RegularTileIterator(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } - - CUTLASS_DEVICE - RegularTileIterator( - TensorRef const &ref, - int thread_idx - ): - pointer_(reinterpret_cast(ref.data()) + (ref.offset(ThreadMap::initial_offset(thread_idx)) * sizeof_bits::value / 8)) { - - stride_ = ref.stride()[0]; - increment_strided_ = (ref.stride()[0] * sizeof_bits::value) * ThreadMap::Delta::kStrided / 8; - - increment_advance_ = - (kAdvanceRank == 0 ? - Shape::kContiguous * sizeof_bits::value / 8 : - Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8)); - } - - /// Loads a fragment - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType const *access_ptr = reinterpret_cast(byte_pointer); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int idx = c + s * ThreadMap::Iterations::kContiguous; - frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess]; - } - - if (s + 1 < ThreadMap::Iterations::kStrided) { - byte_pointer += increment_strided_; - } - } - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag, TensorCoord const & tile_offset) { - load_with_pointer_offset( - frag, - tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_ - ); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - - AccessType const *frag_ptr = reinterpret_cast(&frag); - uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = reinterpret_cast(byte_pointer); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int idx = c + s * ThreadMap::Iterations::kContiguous; - access_ptr[c * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess] = frag_ptr[idx]; - } - - if (s + 1 < ThreadMap::Iterations::kStrided) { - byte_pointer += increment_strided_; - } - } - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, TensorCoord const & tile_offset) { - store_with_pointer_offset( - frag, - tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ - ); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - pointer_ += increment_advance_; - return *this; - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator &operator--() { - pointer_ -= increment_advance_; - return *this; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset; - } - - /// Adds a tile offset in the unit of tile. - /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. - /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. - /// For row major A operand, k dimension is contiguous dimension; - /// For col major A operand, k dimension is strided dimension; - /// For row major B operand, k dimension is strided dimension; - /// For col major B operand, k dimension is contiguous dimension. - /// Below two classes map col/row major to the pitch linear coordinates used - /// in this base class. - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - int offset = sizeof_bits::value * - (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; - add_pointer_offset(offset); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { -#if 0 - AccessType *access_ptr = pointer_[iteration_strided_ & 1]; - int stride_idx = (iteration_strided_ & ~1); - - int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + - iteration_contiguous_ * ThreadMap::Delta::kContiguous / - ThreadMap::kElementsPerAccess; - - char *access_byte_ptr = - reinterpret_cast(access_ptr + access_offset); - return reinterpret_cast(access_byte_ptr + byte_offset_); -#endif - return reinterpret_cast(pointer_); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Regular tile iterator specialized for row major -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator { -public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using Fragment = Array; - - using Underlying = RegularTileIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - kAlignment - >; - - using AccessType = typename Underlying::AccessType; - - static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, - "Advance rank may only be along the row or column dimensions."); - -private: - - Underlying iterator_; - -public: - - CUTLASS_DEVICE - RegularTileIterator() { } - - CUTLASS_DEVICE - RegularTileIterator( - TensorRef const &ref, - int thread_idx - ): - iterator_({ref.data(), ref.stride()}, thread_idx) { - - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag, TensorCoord const & tile_offset) { - iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag) { - iterator_.load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, TensorCoord const & tile_offset) { - iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) { - iterator_.store_with_pointer_offset(frag, 0); - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator &operator--() { - --iterator_; - return *this; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return iterator_.get(); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Regular tile iterator specialized for pitch-linear -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator { -public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajor; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using Fragment = Array; - - using Underlying = RegularTileIterator< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap - >; - - using AccessType = typename Underlying::AccessType; - - static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, - "Advance rank may only be along the row or column dimensions."); - -private: - - Underlying iterator_; - -public: - - CUTLASS_DEVICE - RegularTileIterator() { } - - CUTLASS_DEVICE - RegularTileIterator( - TensorRef const &ref, - int thread_idx - ): - iterator_({ref.data(), ref.stride()}, thread_idx) { - - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag, TensorCoord const & tile_offset) { - iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag) { - iterator_.load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, TensorCoord const & tile_offset) { - iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) { - iterator_.store_with_pointer_offset(frag, 0); - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator &operator--() { - --iterator_; - return *this; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int index) { - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return iterator_.get(); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h deleted file mode 100644 index 5ed2e7fdd08ceafe772c97ab90f915c2268cabbb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h +++ /dev/null @@ -1,509 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile - first, with the objective of minimizing predicate mask updates during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be stored in registers, - and integer addition is used to advance the pointer through memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 -> -class RegularTileIterator2dThreadTile; - - -/// Regular tile iterator specialized for pitch-linear + 2d thread-tiled threadmapping -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator2dThreadTile { -public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::PitchLinear; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using Fragment = Array; - - static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, - "Advance rank may only be along the contiguous or strided dimensions."); - -private: - - // - // Types - // - - using AccessType = AlignedArray; - - // - // Data members - // - - /// Pointer to memory - uint8_t *pointer_; - - /// Stride quantity - StrideIndex stride_; - - /// Amount to increment pointer along strided dimension - LongIndex increment_strided_; - - /// Amount to advance pointer between tiles - LongIndex increment_advance_; - -public: - - CUTLASS_DEVICE - RegularTileIterator2dThreadTile(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } - - CUTLASS_DEVICE - RegularTileIterator2dThreadTile( - TensorRef const &ref, - int thread_idx, - int interleave - ){ - - TensorCoord t = ThreadMap::initial_offset(thread_idx); - long int offset = t[0] * interleave + t[1] * ref.stride()[0]/interleave; - pointer_ = reinterpret_cast(ref.data() + offset); - - stride_ = ref.stride()[0] / interleave; - increment_strided_ = (ref.stride()[0] * sizeof_bits::value / 8) * ThreadMap::Delta::kStrided / interleave; - - increment_advance_ = - (kAdvanceRank == 0 ? - Shape::kContiguous * sizeof_bits::value / 8 : - Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8) / interleave); - } - - /// Loads a fragment - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType const *access_ptr = reinterpret_cast(byte_pointer); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int idx = c + s * ThreadMap::Iterations::kContiguous; - frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided]; - } - - if (s + 1 < ThreadMap::Iterations::kStrided) { - byte_pointer += increment_strided_; - } - } - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag, TensorCoord const & tile_offset) { - load_with_pointer_offset( - frag, - tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_ - ); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - - AccessType const *frag_ptr = reinterpret_cast(&frag); - uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = reinterpret_cast(byte_pointer); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int idx = c + s * ThreadMap::Iterations::kContiguous; - access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx]; - } - - if (s + 1 < ThreadMap::Iterations::kStrided) { - byte_pointer += increment_strided_; - } - } - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, TensorCoord const & tile_offset) { - store_with_pointer_offset( - frag, - tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ - ); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator2dThreadTile &operator++() { - pointer_ += increment_advance_; - return *this; - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator2dThreadTile &operator--() { - pointer_ -= increment_advance_; - return *this; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - pointer_ += pointer_offset; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - int offset = sizeof_bits::value * - (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; - add_pointer_offset(offset); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { -public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorInterleaved<4>; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using Fragment = Array; - - using Underlying = RegularTileIterator2dThreadTile< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap, - kAlignment - >; - - static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, - "Advance rank may only be along the row or column dimensions."); - -private: - - Underlying iterator_; - -public: - - CUTLASS_DEVICE - RegularTileIterator2dThreadTile() { } - - CUTLASS_DEVICE - RegularTileIterator2dThreadTile( - TensorRef const &ref, - int thread_idx - ): - iterator_({ref.data(), ref.stride()}, thread_idx, 4) { - - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag, TensorCoord const & tile_offset) { - iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag) { - iterator_.load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, TensorCoord const & tile_offset) { - iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) { - iterator_.store_with_pointer_offset(frag, 0); - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator2dThreadTile &operator++() { - ++iterator_; - return *this; - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator2dThreadTile &operator--() { - --iterator_; - return *this; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { -public: - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorInterleaved<4>; - static int const kAdvanceRank = AdvanceRank; - using ThreadMap = ThreadMap_; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using Fragment = Array; - using PitchLinearThreadMap = PitchLinearStripminedThreadMap< layout::PitchLinearShape, - ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >; - - - using Underlying = RegularTileIterator2dThreadTile< - layout::PitchLinearShape, - Element, - layout::PitchLinear, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap - >; - - static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, - "Advance rank may only be along the row or column dimensions."); - -private: - - Underlying iterator_; - -public: - - CUTLASS_DEVICE - RegularTileIterator2dThreadTile() { } - - CUTLASS_DEVICE - RegularTileIterator2dThreadTile( - TensorRef const &ref, - int thread_idx - ): - iterator_({ref.data(), ref.stride()}, thread_idx, 4) { - - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag, TensorCoord const & tile_offset) { - iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); - } - - /// Loads a fragment - CUTLASS_HOST_DEVICE - void load(Fragment &frag) { - iterator_.load_with_pointer_offset(frag, 0); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag, TensorCoord const & tile_offset) { - iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); - } - - /// Stores a fragment - CUTLASS_HOST_DEVICE - void store(Fragment const &frag) { - iterator_.store_with_pointer_offset(frag, 0); - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator2dThreadTile &operator++() { - ++iterator_; - return *this; - } - - /// Advances the pointer - CUTLASS_HOST_DEVICE - RegularTileIterator2dThreadTile &operator--() { - --iterator_; - return *this; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h deleted file mode 100644 index 723f328d976fc170d198282823e3da6876ec1ba6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h +++ /dev/null @@ -1,1107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. -*/ - -#pragma once - -#include "cutlass/transform/threadblock/regular_tile_iterator.h" -#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator< - Shape_, Element_, - layout::TensorOpMultiplicandCongruous::value, - Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::TensorOpMultiplicandCongruous::value, - Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - - /// This iterator is specialized for an access size that is 128 bits in length. - static int const kAccessSizeInBits = 128; - - static_assert( - sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - }; - -private: - - /// Element type per access - using AccessType = Array; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = RegularTileAccessIterator; - -private: - - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : address_iterator_(ref, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - address_iterator_.add_tile_offset({0, 1}); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - address_iterator_.add_tile_offset(coord); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, Index byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType const *access_ptr = reinterpret_cast(byte_ptr); - - frag_ptr[access_idx] = *access_ptr; - ++address_iterator_; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, Index byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType *access_ptr = reinterpret_cast(byte_ptr); - - *access_ptr = frag_ptr[access_idx]; - ++address_iterator_; - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_byte_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator< - Shape_, Element_, - layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous::value, - Crosswise>, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - /// Underlying iterator - UnderlyingIterator iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): iterator_({ref.data(), ref.stride()}, thread_id) { - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator< - Shape_, Element_, - layout::RowMajorTensorOpMultiplicandCongruous::value, - Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCongruous::value, - Crosswise>, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - /// Underlying iterator - UnderlyingIterator iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): iterator_({ref.data(), ref.stride()}, thread_id) { - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for crosswise arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::TensorOpMultiplicandCrosswise::value, - Crosswise>; - - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 128; - - static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - }; - - private: - /// Element type per access - using AccessType = Array; - - public: - /// Fragment object to be loaded or stored - using Fragment = - Array; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = RegularTileAccessIterator; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : address_iterator_(ref, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - address_iterator_.add_tile_offset({1, 0}); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - address_iterator_.add_tile_offset(coord); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - address_iterator_.set_iteration_index(0); - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); - ++address_iterator_; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); - } - - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, Index byte_offset) { - address_iterator_.set_iteration_index(0); - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; - AccessType *access_ptr = reinterpret_cast(byte_ptr); - - *access_ptr = frag_ptr[access_idx]; - ++address_iterator_; - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCrosswise::value, - Crosswise>, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - public: - /// Fragment object to be loaded or stored - using Fragment = Array; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator::value, Crosswise>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Crosswise>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, Element, - layout::TensorOpMultiplicandCrosswise::value, - Crosswise>, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - public: - /// Fragment object to be loaded or stored - using Fragment = Array; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for k interleaved arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template -class RegularTileIterator< - Shape_, Element_, - layout::TensorOpMultiplicandRowMajorInterleaved::value, - InterleavedK>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::TensorOpMultiplicandRowMajorInterleaved::value, - InterleavedK>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - /// This iterator is specialized for an access size that is 128 bits in - /// length. - static int const kAccessSizeInBits = 128; - - static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == - kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - }; - - private: - - /// Element type per access - using AccessType = Array; - - public: - /// Fragment object to be loaded or stored - using Fragment = - Array; - - /// Underlying iterator to compute the addresses - using TileAccessIterator = RegularTileAccessIterator; - - private: - // - // Data members - // - - /// Data member to the tile access iterator - TileAccessIterator address_iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : address_iterator_(ref, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - address_iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - address_iterator_.add_pointer_offset(Shape::kCount); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - address_iterator_.add_pointer_offset(coord.contiguous() * Shape::kCount); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - address_iterator_.set_iteration_index(0); - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); - ++address_iterator_; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; - ++address_iterator_; - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for k interleaved arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// - -template -class RegularTileIterator< - Shape_, Element_, - layout::TensorOpMultiplicandColumnMajorInterleaved::value, - InterleavedK>, - AdvanceRank, ThreadMap_, Alignment> { - - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::TensorOpMultiplicandColumnMajorInterleaved::value, - InterleavedK>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - cutlass::MatrixShape, - Element, - layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, - (kAdvanceRank == 1 ? 0 : 1), - ThreadMap - >; - - public: - /// Fragment object to be loaded or stored - using Fragment = Array; - - private: - - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h deleted file mode 100644 index 53121c6114cc3675e4d97f9da65d3ecb58e46d62..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h +++ /dev/null @@ -1,1460 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. - - This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile - first, with the objective of minimizing predicate mask updates during steady-state operation. - - A precomputed "Params" object minimizes the amount of state that must be stored in registers, - and integer addition is used to advance the pointer through memory. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor_op_multiplicand_sm70.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, - Element_, - layout::VoltaTensorOpMultiplicandCongruous::value>, - AdvanceRank, - ThreadMap_, - Alignment> { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::VoltaTensorOpMultiplicandCongruous::value>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - - /// This iterator is specialized for an access size that is 128 bits in length. - static int const kAccessSizeInBits = 128; - - static_assert( - sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - - ///< Number of pointers - static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); - }; - - -private: - - /// Element type per access - using AccessType = Array; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType * pointer_[Detail::kPointerCount]; - - /// Internal byte offset - Index byte_offset_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Detail::kPointerCount; ++i) { - - // This is the offset of a thread within a threadblock tile for a specific pointer - // (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = - thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; - - // initialize pointer - pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset( - coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + - coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess - ); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = pointer_[s & 1]; - int stride_idx = (s & ~1); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + - c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + - vec_pointer_offset; - - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); - - frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - AccessType const *frag_ptr = reinterpret_cast(&frag); - - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = pointer_[s & 1]; - int stride_idx = (s & ~1); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + - c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + - vec_pointer_offset; - - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); - - *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, - Element_, - layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>, - AdvanceRank, - ThreadMap_, - Alignment> { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, - Element, - layout::VoltaTensorOpMultiplicandCongruous::value>, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap_>; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - /// Underlying iterator - UnderlyingIterator iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): iterator_({ref.data(), ref.stride()}, thread_id) { - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, - Element_, - layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>, - AdvanceRank, - ThreadMap_, - Alignment> { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, - Element, - layout::VoltaTensorOpMultiplicandCongruous::value>, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap_>; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - /// Underlying iterator - UnderlyingIterator iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): iterator_({ref.data(), ref.stride()}, thread_id) { - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; -/// Tile iterator specialized for congruous arrangements for TensorOps -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, - Element_, - layout::VoltaTensorOpMultiplicandBCongruous::value>, - AdvanceRank, - ThreadMap_, - Alignment> { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::VoltaTensorOpMultiplicandBCongruous::value>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using StrideIndex = typename Layout::Stride::Index; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - - /// This iterator is specialized for an access size that is 128 bits in length. - static int const kAccessSizeInBits = 128; - - static_assert( - sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, - "This iterator requires a policy whose access size is 128bs"); - - ///< Number of pointers - static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); - }; - - -private: - - /// Element type per access - using AccessType = Array; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - // - // Data members - // - - /// Stride value - StrideIndex stride_; - - /// Internal pointer to first access of tile - AccessType * pointer_[Detail::kPointerCount]; - - /// Internal byte offset - Index byte_offset_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Detail::kPointerCount; ++i) { - - // This is the offset of a thread within a threadblock tile for a specific pointer - // (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = - thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; - - // initialize pointer - pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); - - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset( - coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + - coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess - ); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = pointer_[s & 1]; - int stride_idx = (s & ~1); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + - c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + - vec_pointer_offset; - - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); - - frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - AccessType const *frag_ptr = reinterpret_cast(&frag); - - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = pointer_[s & 1]; - int stride_idx = (s & ~1); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + - c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + - vec_pointer_offset; - - int access_idx = c + s * ThreadMap::Iterations::kContiguous; - - char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); - - *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, - Element_, - layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>, - AdvanceRank, - ThreadMap_, - Alignment> { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, - Element, - layout::VoltaTensorOpMultiplicandBCongruous::value>, - (kAdvanceRank == 0 ? 0 : 1), - ThreadMap_>; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - /// Underlying iterator - UnderlyingIterator iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): iterator_({ref.data(), ref.stride()}, thread_id) { - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major congruous TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, - Element_, - layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>, - AdvanceRank, - ThreadMap_, - Alignment> { -public: - - static_assert(AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, - Element, - layout::VoltaTensorOpMultiplicandBCongruous::value>, - (kAdvanceRank == 0 ? 1 : 0), - ThreadMap_>; - -public: - - /// Fragment object to be loaded or stored - using Fragment = Array; - -private: - - /// Underlying iterator - UnderlyingIterator iterator_; - -public: - - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator( - TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ): iterator_({ref.data(), ref.stride()}, thread_id) { - - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - load_with_pointer_offset(frag, 0); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset( - Fragment const &frag, - Index pointer_offset) { - - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); - } -}; - - -/// Tile iterator specialized for crosswise arrangements for TensorOps. -/// -/// Volta TN SMEM layout is a little diffrent: -/// Crosseised elements will be stored in a line, while contiguous elements -/// sre stored in line-by-line. -/// Padding is used to reduce SMEM bank conflicts. -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator< - Shape_, Element_, - layout::VoltaTensorOpMultiplicandCrosswise::value, - Shape_::kContiguous>, - AdvanceRank, ThreadMap_, Alignment> { - - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = - layout::VoltaTensorOpMultiplicandCrosswise::value, - Shape::kContiguous>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Internal details made public to facilitate introspection - struct Detail { - - ///< Number of pointers - static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); - - /// Iterations for the kElementsPerAccess of ThreadMap - static int const kIterarionsPerAccess = - ThreadMap::kElementsPerAccess / Layout::kElementsPerAccess; - - /// Contiguous elements per line - static int const kContiguousElementsPerLine = 4; - }; - - private: - /// Element type per access - using AccessType = Array; - - public: - /// Fragment object to be loaded or stored - using Fragment = - Array; - - private: - // - // Data members - // - - /// The crosswised elements will be stored in a line. - /// line_size is size of crosswised dimension plus padding. - /// in units of AccessType - Index line_size; - - /// Internal pointer to first access of tile - AccessType *pointer_[Detail::kPointerCount]; - - /// Internal byte offset - Index byte_offset_; - - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : line_size(ref.stride(0) * Detail::kContiguousElementsPerLine / Layout::kElementsPerAccess), - byte_offset_(0) { - - layout::PitchLinearCoord thread_offset_base = - ThreadMap::initial_offset(thread_id); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Detail::kPointerCount; ++i) { - // This is the offset of a thread within a threadblock tile for a specific - // pointer (units of elements) - layout::PitchLinearCoord thread_offset_in_threadblock_tile = - thread_offset_base + - layout::PitchLinearCoord{ - 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; - - // initialize pointer - pointer_[i] = reinterpret_cast( - ref.data() + ref.offset(thread_offset_in_threadblock_tile)); - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_offset_ += pointer_offset * sizeof(Element); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - // (Shape::kContiguous/Layout::kElementsPerAccess)* - // line_size * Layout::kElementsPerAccess - add_pointer_offset(Shape::kContiguous * line_size); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - this->operator++(); - - return prev; - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - add_pointer_offset((coord.contiguous() * (Shape::kContiguous / Layout::kElementsPerAccess) * - line_size + coord.strided() * Shape::kStrided) * - Layout::kElementsPerAccess); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - AccessType *frag_ptr = reinterpret_cast(&frag); - - Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - AccessType *access_ptr = pointer_[(s & 1) ^ (s / 2)]; - - access_ptr += 16 * (s / 2); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { - - int access_offset = - c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + - vec_pointer_offset + i * line_size; - - int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * - Detail::kIterarionsPerAccess + i; - - char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); - - frag_ptr[access_idx] = *reinterpret_cast( - access_byte_ptr + byte_offset_); - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - AccessType const *frag_ptr = reinterpret_cast(&frag); - - Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - - AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; - - access_ptr += 16 * (s / 2) + vec_pointer_offset; - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - CUTLASS_PRAGMA_UNROLL - for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { - - int access_offset = - c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; - - int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * - Detail::kIterarionsPerAccess + i; - - char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); - - *reinterpret_cast(access_byte_ptr + byte_offset_) = - frag_ptr[access_idx]; - } - } - } - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for column-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator::value, Shape_::kRow>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for column-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kRow>; - static int const kAdvanceRank = AdvanceRank; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, Element, - layout::VoltaTensorOpMultiplicandCrosswise::value, - Shape::kRow>, - (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; - - public: - /// Fragment object to be loaded or stored - using Fragment = Array; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.row(), coord.column()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile Iterator specialized for row-major crosswise TensorOp formats. -/// -/// -/// Satisfies: ForwardTileIteratorConcept | -/// ReadableContiguousTileIteratorConcept | -/// WriteableContiguousTileIteratorConcept -/// -template < - typename Shape_, - typename Element_, - int AdvanceRank, - typename ThreadMap_, - int Alignment -> -class RegularTileIterator::value, Shape_::kColumn>, - AdvanceRank, ThreadMap_, Alignment> { - public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for row-major iterator may along advance along the " - "columns(rank=0) or rows(rank=1) dimension."); - - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kColumn>; - static int const kAdvanceRank = AdvanceRank; - static int const kAlignment = Alignment; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorCoord = typename Layout::TensorCoord; - - using ThreadMap = ThreadMap_; - - /// Underlying iterator type - using UnderlyingIterator = RegularTileIterator< - layout::PitchLinearShape, Element, - layout::VoltaTensorOpMultiplicandCrosswise::value, - Shape::kColumn>, - (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; - - public: - /// Fragment object to be loaded or stored - using Fragment = Array; - - private: - /// Underlying iterator - UnderlyingIterator iterator_; - - public: - /// Construct a TileIterator with zero threadblock offset - CUTLASS_HOST_DEVICE - RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) - : iterator_({ref.data(), ref.stride()}, thread_id) {} - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - iterator_.add_pointer_offset(pointer_offset); - } - - /// Adds a tile offset - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { - iterator_.add_tile_offset({coord.column(), coord.row()}); - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator &operator++() { - ++iterator_; - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - RegularTileIterator operator++(int) { - RegularTileIterator prev(*this); - ++iterator_; - - return prev; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - iterator_.load_with_pointer_offset(frag, pointer_offset); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { - iterator_.store_with_pointer_offset(frag, pointer_offset); - } - - /// Store a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h deleted file mode 100644 index 8e5d181c177b2ad6627c927ae4ad3fb9c99a96d3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h +++ /dev/null @@ -1,149 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template wraps the vector access iterator concept to load whole vector from tensors in - memory. This is typically used for per-channel scale and bias in convolution kernels. -*/ - -#pragma once - -#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace transform { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class VectorIterator { -public: - using VectorAccessIterator = VectorAccessIterator_; - - using Shape = typename VectorAccessIterator::Shape; - using Element = typename VectorAccessIterator::Element; - using Layout = typename VectorAccessIterator::Layout; - using TensorCoord = typename Layout::TensorCoord; - using AccessType = typename VectorAccessIterator::AccessType; - using TensorRef = typename VectorAccessIterator::TensorRef; - using Index = typename VectorAccessIterator::Index; - using LongIndex = typename VectorAccessIterator::LongIndex; - - static int const kElementsPerAccess = VectorAccessIterator::kElementsPerAccess; - static int const kRowsPerIteration = VectorAccessIterator::kRowsPerIteration; - static int const kThreads = VectorAccessIterator::kThreads; - static int const kIterations = VectorAccessIterator::kIterations; - - /// Fragment object to be loaded or stored - using Fragment = cutlass::Array< - Element, kElementsPerAccess * kIterations>; - -private: - - /// Internal state - VectorAccessIterator vector_access_iterator_; - -public: - - /// Constructor - CUTLASS_HOST_DEVICE - VectorIterator( - Element const *ptr, - TensorCoord extent, - int thread_idx, - int warp_idx, - MatrixCoord const &threadblock_offset = MatrixCoord() - ): - vector_access_iterator_(ptr, extent, thread_idx, warp_idx, threadblock_offset) { } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - VectorIterator &operator++() { - vector_access_iterator_.advance(); - return *this; - } - - /// Advances to the next tile in memory. - CUTLASS_HOST_DEVICE - VectorIterator operator++(int) { - VectorIterator self(*this); - operator++(); - return self; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { - - frag.clear(); - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < kIterations; ++c) { - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[c], - vector_access_iterator_.get() + pointer_offset, - vector_access_iterator_.valid() - ); - - ++vector_access_iterator_; - } -// } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - vector_access_iterator_.set_iteration_index(0); - load_with_pointer_offset(frag, 0); - } - - CUTLASS_DEVICE - void advance() { - vector_access_iterator_.advance(); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace transform -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h deleted file mode 100644 index b27b77f9b697476ed54a019cd94120561371ebd1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h +++ /dev/null @@ -1,283 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -/*! \file - \brief This defines a "fragment" iterator for visiting the fragments of a warp vector - that participate in one warp-level mma operation. - - Typically, this is used to access the scale/bias fragment of a warp-level mma operation. - The scale/bias vector is then partitioned into smaller fragments that can be fed into - next warp-level mma operation. - - This iterator is necessary to accomplish warp-level mma fusion where the scale/bias vector is - applied to the multiplicand for the next mma. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_conversion.h" - -namespace cutlass { -namespace transform { -namespace warp { - - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the input fragment tile shape (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Layout of operand in memory - typename Layout_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - //// Number of elements per access when loading fragment - int ElementsPerAccess> -class VectorFragmentIterator; - - -// Partial specialization for PitchLinear layout tile - -template < - /// Size of the input fragment vector shape (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - //// Number of elements per access when loading fragment - int ElementsPerAccess> -class VectorFragmentIterator { - public: - - /// Size of the input threadblock tile shape (concept: MatrixShape) - using Shape = Shape_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::PitchLinear; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Number of participating threads - static int const kThreads = 32; - - static int const kElementsPerAccess = ElementsPerAccess; - static int const kRowsPerIteration = 8; - static int const kColumnsPerAccess = 8; - static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kK / kThreads; - static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; - - /// Number of iterations - using Iterations = MatrixShape; - -public: - - // - // Derived quantities - // - // All fragments have kElementsPerAccess scale followed by bias - - /// Fragment object holding a thread's part of a tile - /// This is the fragment size produced by one iteration of the iterator. - using Fragment = Array; - - /// Input threadblock fragment tile - using ThreadblockFragment = Array; - -private: - - /// Internal access type - using AccessType = Array; - -private: - // - // Data members - // - - /// Input threadblock fragment tile - AccessType const *iterator_; - - /// Internal index - int index_; - -public: - /// Constructs an iterator - CUTLASS_HOST_DEVICE - VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) - : iterator_(reinterpret_cast(&threadblock_frag)), - index_(0) {} - - /// Add offset - CUTLASS_HOST_DEVICE - void add_offset(int index_offset) { - index_ += index_offset; - - if(index_ >= Iterations::kColumn) - index_ = 0; - } - - /// Increments - CUTLASS_HOST_DEVICE - VectorFragmentIterator &operator++() { - add_offset(1); - return *this; - } - - CUTLASS_HOST_DEVICE - void set_index(int idx) { - index_ = idx; - } - - /// Loads a fragment from the referenced part of the accumulator tile - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int r = 0; r < Iterations::kRow; r++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kAccessPerIteration; i++) { - - frag_ptr[i * Iterations::kRow + r].clear(); - frag_ptr[i * Iterations::kRow + r] = iterator_[index_ * kAccessPerIteration + i]; - } - } - } - -}; - -// Partial specialization for Row-Major layout tile - -template < - /// Size of the input fragment tile shape (concept: MatrixShape) - typename Shape_, - /// Element type - typename Element_, - /// Shape of one matrix product operation (concept: MatrixShape) - typename InstructionShape_, - //// Number of elements per access when loading fragment - int ElementsPerAccess> -class VectorFragmentIterator { - public: - - /// Size of the input threadblock tile shape (concept: MatrixShape) - using Shape = Shape_; - - /// Element type - using Element = Element_; - - /// Layout of source tile - using Layout = cutlass::layout::RowMajor; - - /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = InstructionShape_; - - /// Underlying iterator - using Base = VectorFragmentIterator< - layout::PitchLinearShape, Element, - layout::PitchLinear, InstructionShape, ElementsPerAccess>; - - - public: - - // - // Derived quantities - // - /// Fragment object holding a thread's part of a tile - /// This is the fragment size produced by one iteration of the iterator. - using Fragment = typename Base::Fragment; - - /// Input threadblock fragment tile - using ThreadblockFragment = typename Base::ThreadblockFragment; - - private: - /// Underlying iterator - Base iterator_; - -public: - /// Constructs an iterator - CUTLASS_HOST_DEVICE - VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) - : iterator_(threadblock_frag) {} - - /// Add offset - CUTLASS_HOST_DEVICE - void add_offset(int index_offset) { - iterator_.add_offset(index_offset); - } - - /// Increments - CUTLASS_HOST_DEVICE - VectorFragmentIterator &operator++() { - add_offset(1); - return *this; - } - - CUTLASS_HOST_DEVICE - void set_index(int idx) { - iterator_.set_index(idx); - } - - /// Loads a fragment from the referenced part of the accumulator tile - CUTLASS_HOST_DEVICE - void load(Fragment &frag) const { - iterator_.load(frag); - } - -}; - - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace conv -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/uint128.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/uint128.h deleted file mode 100644 index 68896d6b60767221fd41421a0d3fdf75392c3604..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/uint128.h +++ /dev/null @@ -1,269 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. -*/ -#pragma once -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#include -#include -#include -#include -#endif - - -/// Optionally enable GCC's built-in type -#if (defined(__x86_64) || defined (__aarch64__)) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) && defined(__GNUC__) -#define CUTLASS_UINT128_NATIVE -#elif !defined(__CUDA_ARCH__) -// No custom support for 128b arithmetic on device -#if defined(_MSC_VER) && defined(_M_AMD64) -#define CUTLASS_INT128_ARITHMETIC -#include -#if _MSC_VER >= 1920 && !defined(__CUDA_ARCH__) -#define CUTLASS_INT128_ARITHMETIC_DIV -#include -#endif -#endif -#endif - -namespace cutlass { - -///! Unsigned 128b integer type -struct alignas(16) uint128_t -{ - /// Size of one part of the uint's storage in bits - static constexpr int storage_bits_ = 64; - - struct hilo - { - uint64_t lo; - uint64_t hi; - }; - - // Use a union to store either low and high parts or, if present, a built-in 128b integer type. - union { - struct hilo hilo_; - -#if defined(CUTLASS_UINT128_NATIVE) - unsigned __int128 native; -#endif // defined(CUTLASS_UINT128_NATIVE) - }; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - uint128_t() : hilo_{0, 0} {} - - /// Constructor from uint64 - CUTLASS_HOST_DEVICE - uint128_t(uint64_t lo_) : hilo_{lo_, 0} {} - - /// Constructor from two 64b unsigned integers - CUTLASS_HOST_DEVICE - uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {} - - /// Optional constructor from native value -#if defined(CUTLASS_UINT128_NATIVE) - uint128_t(unsigned __int128 value) : native(value) { } -#endif - - /// Lossily cast to uint64 - CUTLASS_HOST_DEVICE - explicit operator uint64_t() const - { - return hilo_.lo; - } - - CUTLASS_HOST_DEVICE - static void exception() - { -#if defined(__CUDA_ARCH__) - asm volatile (" brkpt;\n"); -#else - // throw std::runtime_error("Not yet implemented."); - abort(); -#endif - } - - /// Add - CUTLASS_HOST_DEVICE - uint128_t operator+(uint128_t const& rhs) const - { - uint128_t y{}; -#if defined(CUTLASS_UINT128_NATIVE) - y.native = native + rhs.native; -#else - y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; - y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (y.hilo_.lo < hilo_.lo); -#endif - return y; - } - - /// Subtract - CUTLASS_HOST_DEVICE - uint128_t operator-(uint128_t const& rhs) const - { - uint128_t y{}; -#if defined(CUTLASS_UINT128_NATIVE) - y.native = native - rhs.native; -#else - y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; - y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); -#endif - return y; - } - - /// Multiply by unsigned 64b integer yielding 128b integer - CUTLASS_HOST_DEVICE - uint128_t operator*(uint64_t const& rhs) const - { - uint128_t y{}; -#if defined(CUTLASS_UINT128_NATIVE) - y.native = native * rhs; -#elif defined(CUTLASS_INT128_ARITHMETIC) - // Multiply by the low part - y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); - - // Add the high part and ignore the overflow - uint64_t overflow{0}; - y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -#else - CUTLASS_UNUSED(rhs); - exception(); -#endif - return y; - } - - /// Divide 128b operation by 64b operation yielding a 64b quotient - CUTLASS_HOST_DEVICE - uint64_t operator/(uint64_t const& divisor) const - { - uint64_t quotient{0}; -#if defined(CUTLASS_UINT128_NATIVE) - quotient = uint64_t(native / divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - uint64_t remainder{0}; - quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -#else - CUTLASS_UNUSED(divisor); - exception(); -#endif - return quotient; - } - - /// Divide 128b operation by 64b operation yielding a 64b quotient - CUTLASS_HOST_DEVICE - uint64_t operator%(uint64_t const& divisor) const - { - uint64_t remainder{0}; -#if defined(CUTLASS_UINT128_NATIVE) - remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -#else - CUTLASS_UNUSED(divisor); - exception(); -#endif - return remainder; - } - - /// Computes the quotient and remainder in a single method. - CUTLASS_HOST_DEVICE - uint64_t divmod(uint64_t &remainder, uint64_t divisor) const - { - uint64_t quotient{0}; -#if defined(CUTLASS_UINT128_NATIVE) - quotient = uint64_t(native / divisor); - remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -#else - CUTLASS_UNUSED(remainder); - CUTLASS_UNUSED(divisor); - exception(); -#endif - return quotient; - } - - /// Left-shifts a 128b unsigned integer - CUTLASS_HOST_DEVICE - uint128_t operator<<(int sh) const - { - if (sh == 0) { - return *this; - } - else if (sh >= storage_bits_) { - return uint128_t(0, hilo_.lo << (sh - storage_bits_)); - } - else { - return uint128_t( - (hilo_.lo << sh), - (hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh)) - ); - } - } - - /// Right-shifts a 128b unsigned integer - CUTLASS_HOST_DEVICE - uint128_t operator>>(int sh) const - { - if (sh == 0) { - return *this; - } - else if (sh >= storage_bits_) { - return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0); - } - else { - return uint128_t( - (hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)), - (hilo_.hi >> sh) - ); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/uint256.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/uint256.h deleted file mode 100644 index 3657853557ebccfd6be63ce6ba0fa4d69880d649..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/uint256.h +++ /dev/null @@ -1,93 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Defines an unsigned 256b integer. -*/ - -#pragma once -#include "cutlass/cutlass.h" -#if defined(__CUDACC_RTC__) -#include CUDA_STD_HEADER(cstdint) -#else -#include -#include -#include -#include -#include -#endif -#include "cutlass/uint128.h" - -namespace cutlass { - -///! Unsigned 256b integer type -struct alignas(32) uint256_t { - /// Size of one part of the uint's storage in bits - static constexpr int storage_bits_ = 128; - - struct hilo { - uint128_t lo; - uint128_t hi; - }; - - // Use a union to store either low and high parts. - union { - struct hilo hilo_; - }; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - uint256_t() : hilo_{uint128_t{}, uint128_t{}} {} - - /// Constructor from uint128 - CUTLASS_HOST_DEVICE - uint256_t(uint128_t lo_) : hilo_{lo_, uint128_t{}} {} - - /// Constructor from two 128b unsigned integers - CUTLASS_HOST_DEVICE - uint256_t(uint128_t lo_, uint128_t hi_) : hilo_{lo_, hi_} {} - - /// Lossily cast to uint128_t - CUTLASS_HOST_DEVICE - explicit operator uint128_t() const { - return hilo_.lo; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/version.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/version.h deleted file mode 100644 index 57a73a5fbb41a22ed5e44743c84fa1bbbe0b0075..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/version.h +++ /dev/null @@ -1,80 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include - -#define CUTLASS_MAJOR 4 -#define CUTLASS_MINOR 2 -#define CUTLASS_PATCH 1 - -#ifdef CUTLASS_VERSIONS_GENERATED -#include "cutlass/version_extended.h" -#else -#define CUTLASS_BUILD 0 -#define CUTLASS_REVISION "" -#endif - -#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) - -namespace cutlass { - - inline constexpr uint32_t getVersion() { - return CUTLASS_VERSION; - } - inline constexpr uint32_t getVersionMajor() { - return CUTLASS_MAJOR; - } - inline constexpr uint32_t getVersionMinor() { - return CUTLASS_MINOR; - } - inline constexpr uint32_t getVersionPatch() { - return CUTLASS_PATCH; - } - inline constexpr uint32_t getVersionBuild() { - return CUTLASS_BUILD + 0; - } - - inline std::string getVersionString() { - std::string version = "@CUTLASS_VERSION@"; - if (getVersionBuild()) { - version += "." + std::to_string(getVersionBuild()); - } - return version; - } - - inline std::string getGitRevision() { - return "@CUTLASS_REVISION@"; - } - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h deleted file mode 100644 index 77929f60f73dc07ea2a8e47de1cfb95b5f8859f0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h +++ /dev/null @@ -1,133 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types - and is safe to use in a union. -*/ - -#pragma once - -#include "cutlass/arch/wmma.h" - -#if defined(CUTLASS_ARCH_WMMA_ENABLED) - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/functional.h" - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Wmma array type (WmmaFragmentArray holds elements of type nvcuda::wmma::fragment) -template < - /// Element type - typename T, - /// Number of elements in the array - int N, - /// Whether the element type of T is half_t or __half - bool IsHalfType = (platform::is_same::value || - platform::is_same::value) -> -class WmmaFragmentArray: public Array { -public: - - /// Efficient clear method (override Array::clear()) - CUTLASS_HOST_DEVICE - void clear() - { - for(int i = 0; i < Array::kElements; i++) - { - nvcuda::wmma::fill_fragment((*this)[i], (typename T::element_type)0); - } - } - - CUTLASS_HOST_DEVICE - WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) - { - using element_type = typename T::element_type; - plus add; - - for (int i = 0; i < Array::kElements; i++) - { - (*this)[i] = add((*this)[i], rhs[i]); - } - - return *this; - } -}; - -/// Partial specialization for the case in which T::element_type is -/// half_t or __half. This is needed because the cast (typename T::element_type)0 -/// in the primary template flags as an error when __CUDA_NO_HALF_CONVERSIONS__ -/// is set. -template < - /// Element type - typename T, - /// Number of elements in the array - int N -> -class WmmaFragmentArray: public Array { -public: - - /// Efficient clear method (override Array::clear()) - CUTLASS_HOST_DEVICE - void clear() - { - for(int i = 0; i < Array::kElements; i++) - { - nvcuda::wmma::fill_fragment((*this)[i], __float2half(0.f)); - } - } - - CUTLASS_HOST_DEVICE - WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) - { - using element_type = typename T::element_type; - plus add; - - for (int i = 0; i < Array::kElements; i++) - { - (*this)[i] = add((*this)[i], rhs[i]); - } - - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/workspace.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/workspace.h deleted file mode 100644 index 485ebbe3ae27af7ddc05bc1e36f32b1a4ee65901..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/workspace.h +++ /dev/null @@ -1,154 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Utilities for initializing workspaces -*/ - -#pragma once - -#if !defined(__CUDACC_RTC__) -#include "cuda.h" -#include "cuda_runtime.h" - -#include "cutlass/trace.h" -#endif - -#include "cutlass.h" -#include "cutlass/cuda_host_adapter.hpp" - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -static constexpr int MinWorkspaceAlignment = 16; - -#if !defined(__CUDACC_RTC__) -static Status -zero_workspace( - void* workspace, - size_t workspace_size, - cudaStream_t stream = nullptr, - [[maybe_unused]] CudaHostAdapter *cuda_adapter = nullptr) { - if (workspace_size > 0) { - if (workspace == nullptr) { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - return Status::kErrorWorkspaceNull; - } - - CUTLASS_TRACE_HOST(" clearing workspace"); - -#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER - // - // Use the cuda host adapter - // - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { - return Status::kErrorInternal; - } - } - else { - return Status::kErrorInternal; - } -#else - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_size, stream); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } -#endif - } - - return Status::kSuccess; -} -#endif - -#if !defined(__CUDACC_RTC__) -template -Status -fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - static_assert(sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1, "Unsupported fill type"); - if (fill_count > 0) { - if (workspace == nullptr) { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - return Status::kErrorWorkspaceNull; - } - - CUTLASS_TRACE_HOST(" filling workspace"); - -#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER - // - // Use the cuda host adapter - // - CUTLASS_ASSERT(cuda_adapter); - if (cuda_adapter) { - if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, fill_value, fill_count, stream)) { - return Status::kErrorInternal; - } - } - else { - return Status::kErrorInternal; - } -#else - CUdeviceptr d_workspace = reinterpret_cast(workspace); - CUresult result = CUDA_SUCCESS; - if (sizeof(T) == 4) { - result = cuMemsetD32Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); - } - else if (sizeof(T) == 2) { - result = cuMemsetD16Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); - } - else if (sizeof(T) == 1) { - result = cuMemsetD8Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); - } - - if (CUDA_SUCCESS != result) { - const char** error_string_ptr = nullptr; - (void) cuGetErrorString(result, error_string_ptr); - if (error_string_ptr != nullptr) { - CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned error " << *error_string_ptr); - } - else { - CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned unrecognized error"); - } - return Status::kErrorInternal; - } -#endif - } - - return Status::kSuccess; -} -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py deleted file mode 100644 index cbb617dc20d35f6dd352a84c3964a58fa9bc687e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -# Local module imports -from .dsl import * -from .runtime import * -from ._mlir_helpers import lru_cache_ir -from .env_manager import get_str_env_var, detect_gpu_arch - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py deleted file mode 100644 index 607a24d032c6ef899b586a41d2bb771c381406b0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides MLIR Dialect helper functions -""" - -from . import arith -from .lru_cache_ir import lru_cache_ir - - -__all__ = ["arith", "lru_cache_ir"] - -try: - from . import gpu - - __all__.extend(["gpu"]) -except ImportError: - pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py deleted file mode 100644 index 60cc8db31fd7369d721f3d7c64c5bb8fb03502a8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py +++ /dev/null @@ -1,691 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides MLIR Arith Dialect helper functions -""" - -import array -import numpy as np - -from ..common import * -from ..._mlir import ir # type: ignore -from ..._mlir.extras import types as T # type: ignore -from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore - -from .lru_cache_ir import lru_cache_ir - -# ============================================================================= -# Arith Dialect Helper functions -# ============================================================================= - - -def recast_type(src_type, res_elem_type) -> ir.Type: - if isinstance(src_type, T.VectorType): - if src_type.scalable: - res_type = T.vector( - *src_type.shape, - res_elem_type, - scalable=src_type.scalable, - scalable_dims=src_type.scalable_dims, - ) - else: - res_type = T.vector(*src_type.shape, res_elem_type) - elif isinstance(src_type, T.RankedTensorType): - res_type = T.RankedTensorType.get( - element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides - ) - elif isinstance(src_type, T.UnrankedTensorType): - res_type = T.UnrankedTensorType.get(element_type=res_elem_type) - elif isinstance(src_type, T.MemRefType): - res_type = T.MemRefType.get( - element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides - ) - else: - res_type = res_elem_type - return res_type - - -def is_scalar(ty) -> bool: - return not isinstance( - ty, (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType) - ) - - -def element_type(ty) -> ir.Type: - if not is_scalar(ty): - return ty.element_type - else: - return ty - - -def is_narrow_precision(ty) -> bool: - narrow_types = { - T.f8E8M0FNU(), - T.f8E4M3FN(), - T.f8E4M3(), - T.f8E5M2(), - T.f8E4M3B11FNUZ(), - T.f4E2M1FN(), - T.f6E3M2FN(), - T.f6E2M3FN(), - } - return ty in narrow_types - - -def is_float_type(ty) -> bool: - return ( - arith._is_float_type(ty) - # TODO-upstream: prediction is not correct. Patch here and fix in upstream later - or is_narrow_precision(ty) - or ty in (T.bf16(), T.tf32()) - ) - - -def truncf_to_narrow(res_ty, src, loc, ip): - res_elem_ty = element_type(res_ty) - if res_elem_ty == T.f8E8M0FNU(): - rnd = nvgpu.RoundingMode.RP - else: - rnd = nvgpu.RoundingMode.RN - return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip) - - -def extf_from_narrow(res_ty, src, loc, ip): - src_elem_ty = element_type(src.type) - - # When source type is E8M0, temporary element type has to be bf16 - tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16() - tmp_ty = recast_type(src.type, tmp_elem_ty) - - # narrow -> bf16/f16 -> target type - tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip) - return arith.extf(res_ty, tmp, loc=loc, ip=ip) - - -def bitcast(src, res_elem_type, *, loc=None, ip=None): - res_type = recast_type(src.type, res_elem_type) - return arith.bitcast(res_type, src, loc=loc, ip=ip) - - -def cvtf(src, res_elem_type, *, loc=None, ip=None): - src_elem_type = element_type(src.type) - - if res_elem_type == src_elem_type: - return src - - res_type = recast_type(src.type, res_elem_type) - - # Treat TF32 as F32 and use i32 as intermediate data - # TODO-upstream: update arith to support tf32 <-> f32 conversion - if src_elem_type == T.tf32(): - # tf32 -> i32 - tmp_type = recast_type(src.type, T.i32()) - src = builtin.unrealized_conversion_cast([tmp_type], [src], loc=loc, ip=ip) - # i32 -> f32 - src = bitcast(src, T.f32(), loc=loc, ip=ip) - # f32 -> X with `cvtf` recursively - return cvtf(src, res_elem_type, loc=loc, ip=ip) - - if res_elem_type == T.tf32(): - # X -> f32 with `cvtf`` recursively - tmp = cvtf(src, T.f32(), loc=loc, ip=ip) - # f32 -> i32 - tmp = bitcast(tmp, T.i32(), loc=loc, ip=ip) - # i32 -> tf32 - return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip) - - if res_elem_type.width > src_elem_type.width: - if is_narrow_precision(src_elem_type): - return extf_from_narrow(res_type, src, loc, ip) - else: - return arith.extf(res_type, src, loc=loc, ip=ip) - else: - tmp_mlir_type = recast_type(src.type, T.f32()) - - # f16 -- extf -> f32 -- truncf -> bf16 - # TODO-upstream: update arith to support bf16 <-> f16 conversion? - if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or ( - src_elem_type == T.bf16() and res_elem_type == T.f16() - ): - tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip) - return arith.truncf(res_type, tmp, loc=loc, ip=ip) - - # {f8, f6, f4} -> f16, f32, ... - elif is_narrow_precision(res_elem_type): - return truncf_to_narrow(res_type, src, loc, ip) - else: - return arith.truncf(res_type, src, loc=loc, ip=ip) - - -def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): - res_type = recast_type(src.type, res_elem_type) - # TODO-upstream: update arith to support this kind of conversion - if element_type(src.type) in (T.tf32(), T.bf16()): - src = cvtf(src, T.f32(), loc=loc, ip=ip) - - if signed: - return arith.fptosi(res_type, src, loc=loc, ip=ip) - else: - return arith.fptoui(res_type, src, loc=loc, ip=ip) - - -def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): - res_type = recast_type(src.type, res_elem_type) - - orig_res_type = res_type - # TODO-upstream: update arith to support this kind of conversion - if res_elem_type in (T.tf32(), T.bf16()): - res_type = recast_type(src.type, T.f32()) - - if signed and element_type(src.type).width > 1: - res = arith.sitofp(res_type, src, loc=loc, ip=ip) - else: - res = arith.uitofp(res_type, src, loc=loc, ip=ip) - - if orig_res_type == res_type: - return res - - return cvtf(res, element_type(orig_res_type), loc=loc, ip=ip) - - -def int_to_int(a, dst_elem_type, *, loc=None, ip=None): - src_signed = a.signed - dst_signed = dst_elem_type.signed - src_width = element_type(a.type).width - dst_width = dst_elem_type.width - - dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type) - - if dst_width == src_width: - return a - elif src_signed != False and not dst_signed: - # Signed -> Unsigned - if dst_width > src_width: - return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) - else: - return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) - elif src_signed == dst_signed: - # Same signedness - if dst_width > src_width: - if src_signed != False and src_width > 1: - return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip) - else: - return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) - else: - return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) - else: - # Unsigned -> Signed - if dst_width > src_width: - return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) - else: - # For truncation from unsigned to signed, we need to handle overflow - # First truncate to the target width - trunc = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) - # Then reinterpret as signed - if dst_signed: - return arith.bitcast(dst_mlir_type, trunc, loc=loc, ip=ip) - return trunc - - -# ============================================================================= -# Arith Ops Emitter Helpers -# - assuming type of lhs and rhs match each other -# - op name matches python module operator -# ============================================================================= - - -def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None): - """ - This function provides simplified interface to upstream op builder - arith.truncf(T.vector(shape, new_type), src) - - is simplified as because it's element-wise op which can't change shape - arith.truncf(new_type, src) - """ - if isinstance(src, ir.Value): - src_ty = src.type - else: - src_ty = type(src).mlir_type - src = src.ir_value() - - src_elem_ty = element_type(src_ty) - - if src_elem_ty == res_elem_ty: - return src - elif is_float_type(src_elem_ty) and is_float_type(res_elem_ty): - # float-to-float - return cvtf(src, res_elem_ty, loc=loc, ip=ip) - elif arith._is_integer_like_type(src_elem_ty) and arith._is_integer_like_type( - res_elem_ty - ): - if src_elem_ty.width >= res_elem_ty.width: - cast_op = arith.trunci - else: - if is_signed: - cast_op = arith.extsi - else: - cast_op = arith.extui - - res_ty = recast_type(src_ty, res_elem_ty) - return cast_op(res_ty, src, loc=loc, ip=ip) - elif is_float_type(src_elem_ty) and arith._is_integer_like_type(res_elem_ty): - return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip) - elif arith._is_integer_like_type(src_elem_ty) and is_float_type(res_elem_ty): - return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip) - else: - raise DSLRuntimeError( - f"cast from {src_elem_ty} to {res_elem_ty} is not supported" - ) - - -@lru_cache_ir() -def const(value, ty=None, *, loc=None, ip=None): - """ - Generates dynamic expression for constant values. - """ - from ..typing import Numeric, NumericMeta - from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type - - if isinstance(value, Numeric): - value = value.value - - # Early return - if is_dynamic_expression(value) and ( - value.type.isinstance(value.type) or T.bool().isinstance(value.type) - ): - return value - - # Assume type - if ty is None: - if isinstance(value, float): - ty = T.f32() - elif isinstance(value, bool): - ty = T.bool() - elif isinstance(value, int): - ty = T.i32() - elif isinstance(value, np.ndarray): - ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype)) - value = array.array(value.dtype.kind, value.flatten().tolist()) - else: - raise DSLNotImplemented(f"{type(value)} is not supported") - elif isinstance(ty, NumericMeta): - ty = ty.mlir_type - elif isinstance(ty, ir.Type): - if ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty): - elem_ty = ty.element_type - if isinstance(elem_ty, ir.IntegerType): - attr = ir.IntegerAttr.get(elem_ty, value) - else: - attr = ir.FloatAttr.get(elem_ty, value) - value = ir.DenseElementsAttr.get_splat(ty, attr) - elif arith._is_float_type(ty) and isinstance(value, (bool, int)): - value = float(value) - elif arith._is_integer_like_type(ty) and isinstance(value, float): - value = int(value) - else: - raise DSLNotImplemented(f"type {ty} is not supported") - - return arith.constant(ty, value, loc=loc, ip=ip) - - -def _dispatch_to_rhs_r_op(op): - """Decorator that dispatches to the right-hand-side's reverse operation. - - If the other operand is not an ArithValue or is a subclass (more specific) - of ArithValue, this allows proper method resolution for binary operations. - """ - - def wrapper(self, other, **kwargs): - if not isinstance(other, ArithValue): - if not isinstance(other, (int, float, bool)): - # allows to call other.__rmul__ - return NotImplemented - - return op(self, other, **kwargs) - - return wrapper - - -def _binary_op(op): - """ - Decorator to check if the 'other' argument is an ArithValue. - If not, returns NotImplemented. - """ - - def wrapper(self, other, **kwargs): - # When reach this point, `self` must be cast to base `ArithValue` type - if isinstance(other, (int, float, bool)): - other = const(other, self.type).with_signedness(self.signed) - - # Call the original function - # If sub-class doesn't implement overloaded arithmetic, cast to base class - return op(self, other, **kwargs) - - return wrapper - - -# Operator overloading -@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid) -@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid) -@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid) -@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid) -@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid) -@ir.register_value_caster(ir.Float8E5M2Type.static_typeid) -@ir.register_value_caster(ir.Float8E4M3Type.static_typeid) -@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid) -@ir.register_value_caster(ir.BF16Type.static_typeid) -@ir.register_value_caster(ir.F16Type.static_typeid) -@ir.register_value_caster(ir.FloatTF32Type.static_typeid) -@ir.register_value_caster(ir.F32Type.static_typeid) -@ir.register_value_caster(ir.F64Type.static_typeid) -@ir.register_value_caster(ir.IntegerType.static_typeid) -@ir.register_value_caster(ir.VectorType.static_typeid) -@ir.register_value_caster(ir.RankedTensorType.static_typeid) -class ArithValue(ir.Value): - """Overloads operators for MLIR's Arith dialects binary operations.""" - - def __init__(self, v, signed: Union[bool, None] = None): - if isinstance(v, int): - v = arith.constant(self.type, v) - super().__init__(v) - - elem_ty = element_type(self.type) - self.is_float = arith._is_float_type(elem_ty) - # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL - self.signed = signed and elem_ty.width > 1 - - def with_signedness(self, signed: Union[bool, None]): - return type(self)(self, signed) - - def __neg__(self, *, loc=None, ip=None): - if self.type == T.bool(): - raise TypeError( - "Negation, the operator `-` is not supported for boolean type" - ) - - if self.is_float: - return arith.negf(self, loc=loc, ip=ip) - else: - c0 = arith.constant(self.type, 0, loc=loc, ip=ip) - return arith.subi(c0, self, loc=loc, ip=ip) - - @_binary_op - def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float and other.is_float: - return math.powf(self, other, loc=loc, ip=ip) - elif self.is_float and not other.is_float: - return math.fpowi(self, other, loc=loc, ip=ip) - elif not self.is_float and other.is_float: - lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip) - rhs = cvtf(other, T.f32(), loc=loc, ip=ip) - return math.powf(lhs, rhs, loc=loc, ip=ip) - elif not self.is_float and not other.is_float: - return math.ipowi(self, other, loc=loc, ip=ip) - else: - raise DSLNotImplemented(f"Unsupported '{self} ** {other}'") - - @_binary_op - def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__pow__(self, loc=loc, ip=ip) - - # arith operators - - @_dispatch_to_rhs_r_op - @_binary_op - def __add__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.addf(self, other, loc=loc, ip=ip) - else: - return arith.addi(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.subf(self, other, loc=loc, ip=ip) - else: - return arith.subi(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.mulf(self, other, loc=loc, ip=ip) - else: - return arith.muli(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.divf(self, other, loc=loc, ip=ip) - else: - lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip) - rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip) - return arith.divf(lhs, rhs, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - q = arith.divf(self, other, loc=loc, ip=ip) - return math.floor(q, loc=loc, ip=ip) - elif self.signed != False: - return arith.floordivsi(self, other, loc=loc, ip=ip) - else: - return arith.divui(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.remf(self, other, loc=loc, ip=ip) - elif self.signed != False: - return arith.remsi(self, other, loc=loc, ip=ip) - else: - return arith.remui(self, other, loc=loc, ip=ip) - - @_binary_op - def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__add__(self, loc=loc, ip=ip) - - @_binary_op - def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__sub__(self, loc=loc, ip=ip) - - @_binary_op - def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__mul__(self, loc=loc, ip=ip) - - @_binary_op - def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__truediv__(self, loc=loc, ip=ip) - - @_binary_op - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__floordiv__(self, loc=loc, ip=ip) - - @_binary_op - def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__mod__(self, loc=loc, ip=ip) - - # Comparison operators (comparison doesn't have right-hand-side variants) - @_dispatch_to_rhs_r_op - @_binary_op - def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip) - elif self.signed != False: - return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip) - else: - return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __le__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip) - elif self.signed != False: - return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip) - else: - return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip) - else: - return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - # In Python, bool(float("nan")) is True, so use unordered comparison here - return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip) - else: - return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip) - elif self.signed != False: - return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip) - else: - return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.is_float: - return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip) - elif self.signed != False: - return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip) - else: - return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip) - - # Unary operators - def __invert__(self, *, loc=None, ip=None) -> "ArithValue": - return arith.xori(self, arith.constant(self.type, -1)) - - # Bitwise operations - @_dispatch_to_rhs_r_op - @_binary_op - def __and__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.andi(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __or__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.ori(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.xori(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.signed != False: - return arith.shrsi(self, other, loc=loc, ip=ip) - else: - return arith.shrui(self, other, loc=loc, ip=ip) - - @_dispatch_to_rhs_r_op - @_binary_op - def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.shli(self, other, loc=loc, ip=ip) - - @_binary_op - def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.andi(other, self, loc=loc, ip=ip) - - @_binary_op - def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.ori(other, self, loc=loc, ip=ip) - - @_binary_op - def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue": - return arith.xori(other, self, loc=loc, ip=ip) - - @_binary_op - def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__rshift__(self, loc=loc, ip=ip) - - @_binary_op - def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue": - return other.__lshift__(self, loc=loc, ip=ip) - - def __hash__(self): - return super().__hash__() - - def __str__(self): - return "?" - - def __repr__(self): - return self.__str__() - - -def _min(lhs, rhs, *, loc=None, ip=None): - """ - This function provides a unified interface for building arith min - - Assuming the operands have the same type - """ - from ..dsl import is_dynamic_expression - - if not is_dynamic_expression(lhs): - if not is_dynamic_expression(rhs): - return min(lhs, rhs) - else: - lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip) - else: - if not is_dynamic_expression(rhs): - rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) - - if arith._is_integer_like_type(lhs.type): - if lhs.signed != False: - return arith.minsi(lhs, rhs, loc=loc, ip=ip) - else: - return arith.minui(lhs, rhs, loc=loc, ip=ip) - else: - return arith.minimumf(lhs, rhs, loc=loc, ip=ip) - - -def _max(lhs, rhs, *, loc=None, ip=None): - """ - This function provides a unified interface for building arith max - - Assuming the operands have the same type - """ - from ..dsl import is_dynamic_expression - - if not is_dynamic_expression(lhs): - if not is_dynamic_expression(rhs): - return max(lhs, rhs) - else: - lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip) - else: - if not is_dynamic_expression(rhs): - rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) - - if arith._is_integer_like_type(lhs.type): - if lhs.signed != False: - return arith.maxsi(lhs, rhs, loc=loc, ip=ip) - else: - return arith.maxui(lhs, rhs, loc=loc, ip=ip) - else: - return arith.maximumf(lhs, rhs, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py deleted file mode 100644 index a0b0d0500824f3c5ffc9ae51c7218f40c64b780c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides MLIR GPU Dialect helper functions -""" - - -from ..._mlir import ir -from ..._mlir.dialects import gpu, arith, scf -from ..._mlir.extras import types as T - -from ..common import * - -# ============================================================================= -# GPU Dialect Helper functions -# ============================================================================= - - -def create_async_token(): - token_ty = gpu.AsyncTokenType.get() - token = gpu.wait(token_ty, []) - return token - - -def printf(fmt, *args, threadNumber=-1): - """Generate gpu.printf OP predicated on threadNumber""" - type_formats = [] - for arg in args: - ty_format = None - if ir.IndexType.isinstance(arg.type): - ty_format = "%llu" - if ir.IntegerType.isinstance(arg.type): - width = ir.IntegerType(arg.type).width - if width == 64: - ty_format = "%llu" - elif width == 32: - ty_format = "%d" - elif width == 1: - ty_format = "%i" - if ir.F32Type.isinstance(arg.type): - ty_format = "%f" - if ty_format is None: - raise DSLNotImplemented(arg.type) - type_formats.append(ty_format) - if threadNumber == -1: - gpu.printf(fmt.format(*type_formats) + "\n", args) - if threadNumber != -1: - tidx = gpu.thread_id(gpu.Dimension.x) - predicate = arith.cmpi( - arith.CmpIPredicate.eq, tidx, arith.constant(_T.index(), threadNumber) - ) - if_op = scf.IfOp(predicate) - with ir.InsertionPoint(if_op.then_block): - gpu.printf(fmt.format(*type_formats) + "\n", args) - scf.yield_([]) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py deleted file mode 100644 index 57d717b42f94cfab678e70eceb5cc4d30dd10a45..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides @lru_cache_ir -It extends functools.lru_cache with IR Context awareness. - -Example usage: -from cutlass import ir -from lru_cache_ir import lru_cache_ir - -@lru_cache_ir(ir, maxsize=128, typed=False) -def make_layout(...): -... - -""" - - -from functools import lru_cache, wraps - -from ..._mlir import ir # type: ignore - - -def get_ir_context(func): - """ - Return the context for given func called under ir. - Currently the context includes MLIRContext and InsertionPoint. - """ - try: - if ir: - return (ir.Context.current, ir.InsertionPoint.current) - else: - return None - except ValueError: - return None - - -def lru_cache_ir(maxsize=128, typed=True): - """ - Applies an LRU cache to a given function, with awareness of IR context. - - Usage is similar to functools.lru_cache while taking `ir` as required argument. - - :param ir: The IR object from which to derive the context by `get_ir_context` - :param maxsize: Max cache size, same as functools.lru_cache - :param typed: Whether params are type-sensitive, default to True as IR is type-sensitive - """ - - def decorator(func): - # Use functools.lru_cache with a custom wrapper to control the key generation - @lru_cache(maxsize=maxsize, typed=typed) - def cached_func(context, *args, **kwargs): - return func(*args, **kwargs) - - @wraps(func) - def wrapper(*args, **kwargs): - try: - # Call the cached function with the context - return cached_func(get_ir_context(func), *args, **kwargs) - except (RuntimeError, TypeError): - return func(*args, **kwargs) - - # Expose cache-related methods for introspection - wrapper.cache_clear = cached_func.cache_clear - wrapper.cache_info = cached_func.cache_info - return wrapper - - return decorator diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py deleted file mode 100644 index 3989c75e5462d11d5ca229b757f4e5b45c7ee013..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides MLIR's OP helper functions -""" - - -import inspect -from functools import wraps - -from ..._mlir import ir - - -def dsl_user_op(opFunc): - @wraps(opFunc) - def wrapper(*args, **kwargs): - loc = kwargs.pop("loc", None) - if loc is None: - frame = inspect.currentframe().f_back - file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0) - loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc) - res_or_list = opFunc(*args, **kwargs, loc=loc) - return res_or_list - - return wrapper diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py deleted file mode 100644 index 7b11474c6b5b4fd30fb1feb6fae792fc9e059686..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py +++ /dev/null @@ -1,616 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides helper functions that are generated by the preprocessor. -The preprocessor read through python's ast and changes the input code. -""" - -from typing import Callable, Iterator, Optional, overload -from typing_extensions import deprecated -import warnings -import inspect -from types import BuiltinFunctionType -from functools import lru_cache -from inspect import getmembers - -from .utils.logger import log -from .common import * - -from ._mlir_helpers.arith import ArithValue - - -class Executor: - """ - The Executor class handles dynamic and compile-time (constexpr) execution - of "for" loops and "if-else-elif" statements. - - Methods: - set_functions: Assigns the functions for checking loop bounds and - conditional evaluation. - - for_execute: Generates MLIR for OP - while_execute: Generates MLIR while OP - if_execute: generate MLIR if OP - """ - - def __init__(self): - self._is_dynamic_expression = None - self._loop_execute_range_dynamic = None - self._if_dynamic = None - self._while_dynamic = None - self._compare_executor = None - self._any_executor = None - self._all_executor = None - self._builtin_redirector = None - - def set_functions( - self, - *, - is_dynamic_expression: Callable, - loop_execute_range_dynamic: Callable, - if_dynamic: Callable, - while_dynamic: Callable, - compare_executor: Callable, - any_executor: Callable = None, - all_executor: Callable = None, - builtin_redirector: Callable = None, - ): - self._is_dynamic_expression = is_dynamic_expression - self._loop_execute_range_dynamic = loop_execute_range_dynamic - self._if_dynamic = if_dynamic - self._while_dynamic = while_dynamic - self._compare_executor = compare_executor - self._any_executor = any_executor - self._all_executor = all_executor - self._builtin_redirector = builtin_redirector - - @staticmethod - def convert_to_list(x): - """This function is used to convert x to a list. - If x is None, return an empty list. - If x is not a list, return a list containing x. - Otherwise, return x itself. - """ - if x is None: - return [] - if not isinstance(x, list): - return [x] - return x - - @staticmethod - def converge_ret_val(res): - """This function is used to converge res (the return value) of the function. - If res is None, return None. - If res is a list and has only one element, return the element. - Otherwise, return res itself. - """ - if res is None: - return res - elif isinstance(res, list) and len(res) == 1: - return res[0] - return res - - def for_execute( - self, - func, - start, - stop, - step, - write_args=[], - full_write_args_count=0, - write_args_names=[], - unroll=-1, - unroll_full=False, - prefetch_stages=None, - ): - assert ( - self._loop_execute_range_dynamic - ), "Functions must be set before execution." - log().debug("start [%s] stop [%s] step [%s]", start, stop, step) - - return self._loop_execute_range_dynamic( - func, - start, - stop, - step, - write_args, - full_write_args_count, - write_args_names, - unroll, - unroll_full, - prefetch_stages, - ) - - def if_execute( - self, - pred, - then_block: Callable, - else_block: Optional[Callable] = None, - write_args=[], - full_write_args_count=0, - write_args_names=[], - ): - assert self._if_dynamic, "Functions must be set before execution." - - # MLIR generation - return self._if_dynamic( - pred, - then_block, - else_block, - write_args, - full_write_args_count, - write_args_names, - ) - - def while_execute( - self, - pred, - while_before_block: Callable, - while_after_block: Callable, - write_args=[], - full_write_args_count=0, - write_args_names=[], - ): - assert self._while_dynamic, "Functions must be set before execution." - - # MLIR generation - return self._while_dynamic( - while_before_block, - while_after_block, - write_args, - full_write_args_count, - write_args_names, - ) - - -# ============================================================================= -# Decorator -# ============================================================================= - -executor = Executor() - - -def loop_selector( - start, - stop, - step, - *, - write_args=[], - full_write_args_count=0, - write_args_names=[], - unroll=-1, - unroll_full=False, - prefetch_stages=None, -): - log().debug( - "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]", - start, - stop, - step, - write_args, - full_write_args_count, - write_args_names, - unroll, - unroll_full, - prefetch_stages, - ) - from .typing import Integer, Numeric - - def _maybe_upcast(value): - if isinstance(value, Integer): - value = value.ir_value() - - return value - - start = _maybe_upcast(start) - stop = _maybe_upcast(stop) - step = _maybe_upcast(step) - - def ir_loop(func): - return executor.for_execute( - func, - start, - stop, - step, - write_args, - full_write_args_count, - write_args_names, - unroll, - unroll_full, - prefetch_stages, - ) - - return ir_loop - - -def if_selector(pred, write_args=[]): - log().debug("pred [%s] write_args [%s]", pred, write_args) - # Handle Numeric types here? - - from .typing import Numeric - - if isinstance(pred, Numeric): - pred = pred.value - - def ir_loop(func): - return func(pred, *write_args) - - return ir_loop - - -def while_selector(pred, write_args=[]): - def ir_while_loop(func): - return func(pred, *write_args) - - return ir_while_loop - - -def while_executor( - pred, - while_before_block: Callable, - while_after_block: Callable, - write_args=[], - full_write_args_count=0, - write_args_names=[], -): - return executor.while_execute( - pred, - while_before_block, - while_after_block, - write_args, - full_write_args_count, - write_args_names, - ) - - -def if_executor( - pred, - then_block: Callable, - else_block: Optional[Callable] = None, - write_args=[], - full_write_args_count=0, - write_args_names=[], -): - return executor.if_execute( - pred, - then_block, - else_block, - write_args, - full_write_args_count, - write_args_names, - ) - - -# ============================================================================= -# Range -# ============================================================================= - - -class range: - """ - A range-like object for dynamic loop iteration in the DSL. - - This class provides a range interface similar to Python's built-in range, - but is designed to be preprocessed into constructs for dynamic - loop execution. - - The class supports both single-argument (stop) and three-argument - (start, stop, step) constructors with additional parameters for loop - optimization: - - - unroll: Number of iterations to unroll (0 or 1 = no unrolling) - - unroll_full: Whether to fully unroll the loop - - prefetch_stages: Number of prefetch stages to generate - """ - - @overload - def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None): - pass - - @overload - def __new__( - cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None - ): - pass - - def __new__(cls, *args, **kwargs): - raise DSLRuntimeError("dynamic range should be always preprocessed to IR") - - def __iter__(self) -> Iterator[int]: - raise DSLRuntimeError("dynamic range should be always preprocessed to IR") - - -@deprecated( - "range_dynamic is deprecated and will be removed in the future, please remove it." -) -def range_dynamic(*args, **kwargs): - raise DSLRuntimeError("range_dynamic should be always preprocessed to IR") - - -def range_constexpr(*args): - raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.") - - -# ============================================================================= -# If expressions -# ============================================================================= - - -def const_expr(expression): - """ - This function is used to check if the expression is a python value. - If the expression is a python value, return the boolean value of the expression. - If the expression is a dynamic expression, raise an error. - """ - from .typing import Numeric - - failed = False - - if isinstance(expression, Numeric): - if isinstance(expression.value, (int, float, bool)): - return expression.value - else: - failed = True - elif executor._is_dynamic_expression(expression): - failed = True - - if failed: - raise DSLRuntimeError( - f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).", - context={ - "If your expression depends on dynamic values": "Remove `const_expr()`", - }, - ) - return expression - - -@deprecated( - "dynamic_expr is deprecated and will be removed in the future, please remove it." -) -def dynamic_expr(expression): - return expression - - -# ============================================================================= -# Assertion & casting -# ============================================================================= - - -def assert_executor(test, msg=None): - from .typing import Numeric - - fail = False - # Implicit convert dynamic expression to bool is not allowed - # So here explicitly do a None check - if test is not None and executor._is_dynamic_expression(test): - if isinstance(test, Numeric): - try: - test = test.to(bool) - except: - fail = True - else: - fail = True - - if not fail: - assert test, msg - else: - raise DSLRuntimeError( - "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", - suggestion="Please replace with runtime assert.", - ) - - -def bool_cast(value): - if executor._is_dynamic_expression(value): - raise DSLRuntimeError( - "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", - suggestion="Please explicitly convert to boolean with expressions like comparision.", - ) - return bool(value) - - -def compare_executor(left, comparators, ops): - """ - Executes comparison operations with a left operand and a list of comparators. - - Args: - left: The leftmost value in the comparison chain - comparators: A list of values to compare against - ops: A list of comparison operators to apply - - Returns: - The result of the comparison chain - - Raises: - AssertionError: If the executor function is not set before execution - """ - assert ( - executor._compare_executor is not None - ), "Function must be set before execution." - return executor._compare_executor(left, comparators, ops) - - -def any_executor(iterable): - """Executes the 'any' operation on an iterable, handling both dynamic and static expressions. - - :param iterable: An iterable to check if any elements evaluate to True - :type iterable: Iterable - :return: boolean of Python value or IR value - :rtype: bool or cutlass.Boolean - - """ - if executor._any_executor and executor._is_dynamic_expression(iterable): - return executor._any_executor(iterable) - else: - return any(iterable) - - -def all_executor(iterable): - """Executes the 'all' operation on an iterable, handling both dynamic and static expressions. - - :param iterable: An iterable to check if all elements evaluate to True - :type iterable: Iterable - :return: boolean of Python value or IR value - :rtype: bool or cutlass.Boolean - """ - if executor._all_executor and executor._is_dynamic_expression(iterable): - return executor._all_executor(iterable) - else: - return all(iterable) - - -# ============================================================================= -# Control flow checks -# ============================================================================= -class DSLOptimizationWarning(Warning): - """ - This warning is used to warn the user about the optimization related issues in DSL. - """ - - def __init__(self, message): - self.message = message - super().__init__() - - def __str__(self): - return self.message - - -def range_value_check(*args): - """ - Ensure all `range_constexpr` bounds are compile-time constants (Python ints). - """ - try: - args = tuple(arg.__index__() for arg in args) - - # Compute range size and warn if it's too large - start = 0 - end = 0 - step = 1 - if len(args) == 1: - end = args[0] - elif len(args) == 2: - start = args[0] - end = args[1] - elif len(args) == 3: - start = args[0] - end = args[1] - step = args[2] - - range_length = (abs(end - start) - 1) // abs(step) + 1 - if range_length >= 64: - warnings.warn( - f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.", - category=DSLOptimizationWarning, - stacklevel=2, - ) - - return (start, end, step) - except: - raise DSLRuntimeError( - "`range_constexpr` requires constexpr (compile-time constant) for all arguments.", - suggestion="Use `range` instead of `range_constexpr`.", - ) - - -def range_perf_warning(filename, lineno, *args): - has_dynamic_expr = False - for arg in args: - if executor._is_dynamic_expression(arg): - has_dynamic_expr = True - break - if not has_dynamic_expr: - warnings.warn_explicit( - ( - "This loop is no longer unrolled and may cause performance regression. " - "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants." - ), - category=DSLOptimizationWarning, - filename=filename, - lineno=lineno, - ) - - -@lru_cache(maxsize=1) -def _get_self_module(): - """ - This function is used to get the owning module of this function. - """ - return inspect.getmodule(_get_self_module) - - -def cf_symbol_check(symbol): - """ - Check if the symbol is control flow symbol from current module. - """ - - failed = False - name = symbol.__name__ - self_module = _get_self_module() - if inspect.ismodule(symbol): - name = "range" - if not self_module.__name__.startswith(symbol.__name__): - failed = True - else: - owning_module = inspect.getmodule(symbol) - if owning_module != self_module: - failed = True - - if failed: - raise DSLRuntimeError( - f"Incorrect {symbol.__name__} is used.", - suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.", - ) - - -def redirect_builtin_function(fcn): - """ - This function is used to redirect built-in function call - to the function defined in DSL package. - """ - # Only redirect if it's a built-in - if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector: - return executor._builtin_redirector(fcn) - return fcn - - -def copy_members(dest, src): - """ - Copies all non-callable, non-dunder members from src to dest if they exist in src. - Skips members that are callables or have names starting with double underscores. - """ - if id(dest) == id(src): - return - - members = getmembers(dest) - for name, value in members: - if ( - name.startswith("__") - or isinstance(value, Callable) - or not hasattr(src, name) - ): - continue - setattr(dest, name, getattr(src, name)) - - -def get_locals_or_none(locals, symbols): - """ - Given a locals() dictionary and a list of symbol names, return a list of their values - in the same order as the symbols list. If a symbol is not present in locals, None is returned - for that symbol. - """ - variables = [] - for symbol in symbols: - if symbol in locals: - variables.append(locals[symbol]) - else: - variables.append(None) - return variables diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py deleted file mode 100644 index 11f2d1ae84405a13f7fffd241c6e6bdd6e167010..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py +++ /dev/null @@ -1,1958 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module defines the `DSLPreprocessor` class, which acts as a Python preprocessor. -It uses Python's AST and rewrites specific Python statements such as `for` and `if-else`. - -The preprocessor operates on the following constructs: - - `for` loops: - - Rewrites `for` loops with the `@loop_selector` decorator. - - Supports `range`, `range_dynamic` for loop iteration. - - `if-elif-else` statements: - - Rewrites conditional statements with the `@if_selector` decorator. - - Supports `dynamic_expr` and `const_expr` in the condition expressions. - -Additionally, both `for` loops and `if-else` statements require `yield` -operation generation. The preprocessor handles this by: - - Using a `ScopeManager` to track symbols across different scopes during AST traversal. - - Identifying read-only, read-write, and active variables for DSL constructs. - - Generating `yield` operations for symbols that are classified as read-write or write. - -It is designed to be generic and can handle `for` and `if` constructs from other dialects. -In such cases, the user's DSL should implement `@loop_selector` and `@if_selector` -to generate dialect-specific operations for `for` and `if` statements. -""" - -import ast -import importlib -import inspect -import textwrap -import warnings -from dataclasses import dataclass -from typing import List, Set, Dict, Any, Callable, Optional -from types import ModuleType -from collections import OrderedDict -from copy import deepcopy - -from .common import * -from .utils.logger import log - - -class OrderedSet: - """ - A deterministic set implementation for ordered operations. - """ - - def __init__(self, iterable=None): - self._dict = dict.fromkeys(iterable or []) - - def add(self, item): - self._dict[item] = None - - def __iter__(self): - return iter(self._dict) - - def __and__(self, other): - return OrderedSet(key for key in self._dict if key in other) - - def __or__(self, other): - new_dict = self._dict.copy() - new_dict.update(dict.fromkeys(other)) - return OrderedSet(new_dict) - - def __sub__(self, other): - return OrderedSet(key for key in self._dict if key not in other) - - def intersections(self, others): - """Compute the intersection of this set with multiple other sets. - - :param others: A list of sets to compute intersections with - :type others: List[Set[str]] - :return: A new ordered set containing elements that appear in this set - and at least one of the other sets - """ - result = OrderedSet() - for key in self._dict: - for other in reversed(others): - if key in other: - result.add(key) - break - return result - - -@dataclass -class ImportInfo: - """ - Information about an import expression. - """ - module_path: str - attr_name: Optional[str] - alias_name: str - - -@dataclass -class ScopeManager: - """ - Manages symbol scopes during AST traversal. - Manage nested scopes during transformations. - """ - - scopes: List[Set[str]] - - @classmethod - def create(cls) -> "ScopeManager": - return cls([]) - - def add_to_scope(self, name: str) -> None: - if name == "_": - return - self.scopes[-1].add(name) - - def get_active_symbols(self) -> List[Set[str]]: - return self.scopes.copy() - - def __enter__(self) -> "ScopeManager": - self.scopes.append(set()) - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.scopes.pop() - - -class DSLPreprocessor(ast.NodeTransformer): - """ - A preprocessor for transforming Python ASTs. It supports: - - - Rewriting `for` loops with the `@loop_selector` decorator. - - Rewriting `if-elif-else` statements with the `@if_selector` decorator. - - Generating `yield` operations for read-write or write symbols. - """ - - DECORATOR_FOR_STATEMENT = "loop_selector" - DECORATOR_IF_STATEMENT = "if_selector" - DECORATOR_WHILE_STATEMENT = "while_selector" - IF_EXECUTOR = "if_executor" - WHILE_EXECUTOR = "while_executor" - ASSERT_EXECUTOR = "assert_executor" - BOOL_CAST = "bool_cast" - IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType" - SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"} - COMPARE_EXECUTOR = "compare_executor" - ANY_EXECUTOR = "any_executor" - ALL_EXECUTOR = "all_executor" - - def __init__(self, client_module_name): - super().__init__() - self.counter = 0 # Unique function names for multiple loops - self.scope_manager = ScopeManager.create() - self.processed_functions = set() - self.function_counter = 0 - self.function_name = "" - self.class_name = None - self.file_name = "" - self.function_depth = 0 - self.local_closures = set() - self.function_globals = None - self.client_module_name = client_module_name - self.import_top_module = False - - def _create_module_attribute( - self, - func_name, - *, - top_module_name="_dsl_", - submodule_name="ast_helpers", - lineno=None, - col_offset=None, - ): - # If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong. - def set_location(node, lineno, col_offset): - if lineno and col_offset: - node.lineno = lineno - node.end_lineno = lineno - node.col_offset = col_offset - node.end_col_offset = col_offset - - base = ast.Name(id=top_module_name, ctx=ast.Load()) - set_location(base, lineno, col_offset) - if submodule_name: - base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load()) - set_location(base, lineno, col_offset) - node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load()) - set_location(node, lineno, col_offset) - return node - - def _get_module_imports(self, decorated_func): - """Extract imports from the module containing the decorated function""" - imports = [] - - # Get the module containing the decorated function - if module := inspect.getmodule(decorated_func): - try: - # Get the module source code - source = inspect.getsource(module) - module_ast = ast.parse(source) - - # Extract imports from the full module - alias = lambda n: n.asname if n.asname else n.name - for node in ast.walk(module_ast): - if isinstance(node, ast.Import): - for name in node.names: - imports.append( - ImportInfo( - module_path=name.name, - attr_name=None, - alias_name=alias(name), - ) - ) - elif isinstance(node, ast.ImportFrom): - module_name = node.module - if node.level > 0: - # Handle relative imports - package_name = module.__package__.rsplit( - ".", node.level - 1 - )[0] - module_name = f"{package_name}.{module_name}" - for name in node.names: - imports.append( - ImportInfo( - module_path=module_name, - attr_name=name.name, - alias_name=alias(name), - ) - ) - except (IOError, TypeError): - pass - - return imports - - def exec(self, function_name, original_function, code_object, exec_globals): - # Get imports from the original module - module_imports = self._get_module_imports(original_function) - - # Import all required modules - for import_info in module_imports: - module_path, attr_name, alias_name = ( - import_info.module_path, - import_info.attr_name, - import_info.alias_name, - ) - try: - module = importlib.import_module(module_path) - if attr_name: - if attr_name == "*": - if hasattr(module, "__all__"): - attrs = module.__all__ - else: - attrs = [ - name for name in dir(module) if not name.startswith("_") - ] - else: - attrs = [attr_name] - - for attr in attrs: - alias = attr if attr_name == "*" else alias_name - exec_globals[alias] = getattr(module, attr) - else: - exec_globals[alias_name] = module - except (ImportError, AttributeError) as e: - raise ImportError(f"Failed to import {module_path}: {str(e)}") - - # Execute the transformed code - log().info( - "ASTPreprocessor Executing transformed code for function [%s]", - function_name, - ) - exec(code_object, exec_globals) - return exec_globals.get(function_name) - - @staticmethod - def print_ast(transformed_tree=None): - print("#", "-" * 40, "Transformed AST", "-" * 40) - unparsed_code = ast.unparse(transformed_tree) - print(unparsed_code) - print("#", "-" * 40, "End Transformed AST", "-" * 40) - - def make_func_param_name(self, base_name, used_names): - """Generate a unique parameter name that doesn't collide with existing names.""" - if base_name not in used_names: - return base_name - - i = 0 - while f"{base_name}_{i}" in used_names: - i += 1 - return f"{base_name}_{i}" - - def transform_function(self, func_name, function_pointer): - """ - Transforms a function. - """ - # Skip if the function has already been processed - if function_pointer in self.processed_functions: - log().info( - "ASTPreprocessor Skipping already processed function [%s]", func_name - ) - return [] - - # Step 1. Parse the given function - file_name = inspect.getsourcefile(function_pointer) - lines, start_line = inspect.getsourcelines(function_pointer) - dedented_source = textwrap.dedent("".join(lines)) - tree = ast.parse(dedented_source, filename=file_name) - # Bump the line numbers so they match the real source file - ast.increment_lineno(tree, start_line - 1) - - # Step 1.2 Check the decorator - if not self.check_decorator(tree.body[0]): - log().info( - "[%s] - Skipping function due to missing decorator", - func_name, - ) - return [] - - self.processed_functions.add(function_pointer) - log().info("ASTPreprocessor Transforming function [%s]", func_name) - - # Step 2. Transform the function - transformed_tree = self.visit(tree) - - # Step 3. Import cutlass and base_dsl - top_module_name = ".".join(self.client_module_name) - import_stmts = [] - if self.import_top_module: - import_stmts.append(ast.Import(names=[ast.alias(name=top_module_name)])) - import_stmts.append( - ast.Import( - names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")] - ) - ) - transformed_tree.body = import_stmts + transformed_tree.body - - # Step 4. Import cutlass and base_dsl - ast.fix_missing_locations(transformed_tree) - combined_body = transformed_tree.body - - # Step 5. Return the transformed tree - return combined_body - - def check_early_exit(self, tree, kind): - """ - Checks if a given region or scope in the provided Python code has early exits. - """ - - class EarlyExitChecker(ast.NodeVisitor): - def __init__(self, kind): - self.has_early_exit = False - self.early_exit_node = None - self.early_exit_type = None - self.kind = kind - self.loop_nest_level = 0 - - # Early exit is not allowed in any level of dynamic control flow - def visit_Return(self, node): - self.has_early_exit = True - self.early_exit_node = node - self.early_exit_type = "return" - - def visit_Raise(self, node): - self.has_early_exit = True - self.early_exit_node = node - self.early_exit_type = "raise" - - def visit_Break(self, node): - # For break/continue in inner loops, we don't consider it as early exit - if self.loop_nest_level == 0 and self.kind != "if": - self.has_early_exit = True - self.early_exit_node = node - self.early_exit_type = "break" - - def visit_Continue(self, node): - if self.loop_nest_level == 0 and self.kind != "if": - self.has_early_exit = True - self.early_exit_node = node - self.early_exit_type = "continue" - - def visit_For(self, node): - self.loop_nest_level += 1 - self.generic_visit(node) - self.loop_nest_level -= 1 - - def visit_While(self, node): - self.loop_nest_level += 1 - self.generic_visit(node) - self.loop_nest_level -= 1 - - checker = EarlyExitChecker(kind) - checker.generic_visit(tree) - if not checker.has_early_exit: - return - raise DSLAstPreprocessorError( - message=f"Early exit ({checker.early_exit_type}) is not allowed in `{self.function_name}`" - + (f" in `{self.class_name}`" if self.class_name else ""), - filename=self.file_name, - snippet=ast.unparse(tree), - suggestion=( - "If predicates are constant expression, write like " - "`if const_expr(...)` or `for ... in range_constexpr(...)`. " - "In that case, early exit will be executed by Python " - "interpreter, so it's supported." - ), - ) - - def is_node_constexpr(self, node) -> bool: - """ - Determines if the node is a constexpr. - Supported nodes are if, while statements. - """ - if isinstance(node, ast.If) or isinstance(node, ast.While): - if isinstance(node.test, ast.Call): - func = node.test.func - - if isinstance(func, ast.Attribute) and func.attr == "const_expr": - return True - - elif isinstance(func, ast.Name) and func.id == "const_expr": - return True - return False - - def _get_range_kind(self, iter_node): - """ - Return "range", "range_dynamic", "range_constexpr" or None for the iterable - """ - if isinstance(iter_node, ast.Call): - func = iter_node.func - if ( - isinstance(func, ast.Name) - and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS - ): - return func.id, True, len(iter_node.keywords) != 0 - if ( - isinstance(func, ast.Attribute) - and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS - ): - return func.attr, False, len(iter_node.keywords) != 0 - return None, None, None - - def transform(self, original_function, exec_globals): - """ - Transforms the provided function using the preprocessor. - """ - self.file_name = inspect.getsourcefile(original_function) - self.function_globals = exec_globals - transformed_tree = self.transform_function( - original_function.__name__, original_function - ) - self.function_globals = None - unified_tree = ast.Module(body=transformed_tree, type_ignores=[]) - unified_tree = ast.fix_missing_locations(unified_tree) - - return unified_tree - - def analyze_region_variables( - self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]] - ): - """ - Analyze variables in different code regions to identify read-only, write-only, - and active variables for DSL constructs. - """ - - # we need orderedset to keep the insertion order the same. otherwise generated IR is different each time - write_args = OrderedSet() - invoked_args = OrderedSet() - local_closure = self.local_closures - file_name = self.file_name - region_node = node - - class RegionAnalyzer(ast.NodeVisitor): - force_store = False - - def visit_Name(self, node): - """ - Mark every store as write. - """ - if isinstance(node.ctx, ast.Store) or self.force_store: - write_args.add(node.id) - - def visit_Subscript(self, node): - # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`. - # We need to force the store for the `Name` to be marked as write. - if isinstance(node.ctx, ast.Store): - self.force_store = True - self.visit(node.value) - self.force_store = False - self.visit(node.slice) - else: - self.generic_visit(node) - - def visit_Assign(self, node): - self.force_store = True - [self.visit(target) for target in node.targets] - self.force_store = False - self.visit(node.value) - - def visit_AugAssign(self, node): - self.force_store = True - self.visit(node.target) - self.force_store = False - self.visit(node.value) - - @staticmethod - def get_call_base(func_node): - if isinstance(func_node, ast.Attribute): - # If the .value is another Attribute, keep digging - if isinstance(func_node.value, ast.Attribute): - return RegionAnalyzer.get_call_base(func_node.value) - # If the .value is a Name, that's our base - elif isinstance(func_node.value, ast.Name): - return func_node.value.id - else: - # Could be something else (lambda, call, etc.) - return None - elif isinstance(func_node, ast.Name): - return None - return None - - @staticmethod - def get_function_name(func_node: ast.Call): - if isinstance(func_node.func, ast.Name): - function_name = func_node.func.id - # Check if it's a method or attribute call - elif isinstance(func_node.func, ast.Attribute): - function_name = func_node.func.attr - else: - function_name = None - return function_name - - def visit_Call(self, node): - base_name = RegionAnalyzer.get_call_base(node.func) - - if isinstance(node.func, ast.Name): - func_name = node.func.id - if func_name in local_closure: - raise DSLAstPreprocessorError( - f"Function `{func_name}` is a closure and is not supported in for/if statements", - filename=file_name, - snippet=ast.unparse(region_node), - ) - - # Classes are mutable by default. Mark them as write. If they are - # dataclass(frozen=True), treat them as read in runtime. - if base_name is not None and base_name not in ("self"): - invoked_args.add(base_name) - - self.generic_visit(node) - - analyzer = RegionAnalyzer() - analyzer.visit(ast.Module(body=node)) - - # If arg is both write and invoke, remove from invoked_args - invoked_args = invoked_args - write_args - - write_args = list(write_args.intersections(active_symbols)) - invoked_args = list(invoked_args.intersections(active_symbols)) - - return write_args + invoked_args, len(write_args) - - def extract_range_args(self, iter_node): - args = iter_node.args - if len(args) == 1: - return ( - self.visit(ast.Constant(value=0)), - self.visit(args[0]), - self.visit(ast.Constant(value=1)), - False, - ) - elif len(args) == 2: - return ( - self.visit(args[0]), - self.visit(args[1]), - self.visit(ast.Constant(value=1)), - False, - ) - elif len(args) == 3: - return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]), True - else: - raise DSLAstPreprocessorError( - "Unsupported number of arguments in range", filename=self.file_name - ) - - def extract_unroll_args(self, iter_node): - keywords = {kw.arg: kw.value for kw in iter_node.keywords} - return ( - keywords.get("unroll", ast.Constant(value=-1)), - keywords.get("unroll_full", ast.Constant(value=False)), - ) - - def issue_deprecation_warning(self, *, message, category, filename, lineno): - warnings.simplefilter("always", category) # turn off filter - warnings.warn_explicit( - message, category=category, filename=filename, lineno=lineno - ) - warnings.simplefilter("default", category) # reset filter - - def extract_prefetch_stages_args(self, iter_node): - keywords = {kw.arg: kw.value for kw in iter_node.keywords} - if "pipelining" in keywords: - self.issue_deprecation_warning( - message="pipelining is deprecated, use prefetch_stages instead", - category=DeprecationWarning, - filename=self.file_name, - lineno=iter_node.lineno, - ) - return keywords.get("pipelining", ast.Constant(value=None)) - return keywords.get("prefetch_stages", ast.Constant(value=None)) - - def create_loop_function( - self, - func_name, - node, - start, - stop, - step, - unroll, - unroll_full, - prefetch_stages, - write_args, - full_write_args_count, - ): - """ - Creates a loop body function with the `loop_selector` decorator. - """ - - func_args = [ast.arg(arg=node.target.id, annotation=None)] - func_args += [ast.arg(arg=var, annotation=None) for var in write_args] - - # Create the loop body - transformed_body = [] - for stmt in node.body: - transformed_stmt = self.visit(stmt) # Recursively visit inner statements - if isinstance(transformed_stmt, list): - transformed_body.extend(transformed_stmt) - else: - transformed_body.append(transformed_stmt) - - # Handle the return for a single iterated argument correctly - if len(write_args) == 0: - transformed_body.append(ast.Return()) - else: - transformed_body.append( - ast.Return( - value=ast.List( - elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], - ctx=ast.Load(), - ) - ) - ) - - # Define the decorator with parameters - decorator = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.DECORATOR_FOR_STATEMENT, - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[start, stop, step], - keywords=[ - ast.keyword(arg="unroll", value=unroll), - ast.keyword(arg="unroll_full", value=unroll_full), - ast.keyword(arg="prefetch_stages", value=prefetch_stages), - ast.keyword( - arg="write_args", - value=self.generate_get_locals_or_none_call(write_args), - ), - ast.keyword( - arg="full_write_args_count", - value=ast.Constant(value=full_write_args_count), - ), - ast.keyword( - arg="write_args_names", - value=ast.List( - elts=[ast.Constant(value=arg) for arg in write_args], - ctx=ast.Load(), - ), - ), - ], - ), - node, - ) - - return ast.copy_location( - ast.FunctionDef( - name=func_name, - args=ast.arguments( - posonlyargs=[], - args=func_args, - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=transformed_body, - decorator_list=[decorator], - ), - node, - ) - - def visit_BoolOp(self, node): - # Visit child nodes first - self.generic_visit(node) - - # It is necessary to expand short circuit evaluation explicit here - # Although we do not support inline if-else for IR generation, this is actually evaluated in Python - # So it's fine here - # Transform "and" to "and_" - if isinstance(node.op, ast.And): - # Create an if-else statement in AST form - # if type(lhs) == bool and lhs == False: - # return lhs - # else - # return and_(lhs, rhs) - short_circuit_value = ast.Constant(value=False) - helper_func = self._create_module_attribute( - "and_", - top_module_name="cutlass", - submodule_name=None, - lineno=node.lineno, - col_offset=node.col_offset, - ) - self.import_top_module = True - # Transform "or" to "or_" - elif isinstance(node.op, ast.Or): - # Create an if-else statement in AST form - # if type(lhs) == bool and lhs == True: - # return lhs - # else - # return or_(lhs, rhs) - short_circuit_value = ast.Constant(value=True) - helper_func = self._create_module_attribute( - "or_", - top_module_name="cutlass", - submodule_name=None, - lineno=node.lineno, - col_offset=node.col_offset, - ) - self.import_top_module = True - else: - # BoolOp should be either And or Or - raise DSLAstPreprocessorError( - f"Unsupported boolean operation: {node.op}", - filename=self.file_name, - snippet=ast.unparse(node), - ) - - def short_circuit_eval(value, short_circuit_value): - return ast.BoolOp( - op=ast.And(), - values=[ - ast.Compare( - left=ast.Call( - func=ast.Name(id="type", ctx=ast.Load()), - args=[value], - keywords=[], - ), - ops=[ast.Eq()], - comparators=[ast.Name(id="bool", ctx=ast.Load())], - ), - ast.Compare( - left=value, - ops=[ast.Eq()], - comparators=[short_circuit_value], - ), - ], - ) - - lhs = node.values[0] - - for i in range(1, len(node.values)): - test = short_circuit_eval(lhs, short_circuit_value) - lhs = ast.IfExp( - test=test, - body=lhs, - orelse=ast.Call( - func=helper_func, - args=[lhs, node.values[i]], - keywords=[], - ), - ) - - return ast.copy_location(lhs, node) - - def visit_UnaryOp(self, node): - # Visit child nodes first - self.generic_visit(node) - - # Transform "not" to "~" as we overload __invert__ - if isinstance(node.op, ast.Not): - func_name = self._create_module_attribute( - "not_", - top_module_name="cutlass", - submodule_name=None, - lineno=node.lineno, - col_offset=node.col_offset, - ) - self.import_top_module = True - return ast.copy_location( - ast.Call(func=func_name, args=[node.operand], keywords=[]), node - ) - - return node - - def _insert_range_value_check(self, node): - """ - Insert a check for range arguments - """ - range_inputs = node.iter.args - check_call = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - "range_value_check", lineno=node.lineno, col_offset=node.col_offset - ), - args=range_inputs, - keywords=[], - ), - node.iter, - ) - node.iter = ast.copy_location( - ast.Call( - func=ast.Name(id="range", ctx=ast.Load()), - args=[ast.Starred(value=check_call, ctx=ast.Load())], - keywords=[], - ), - node.iter, - ) - - def _insert_cf_symbol_check(self, func): - """ - Insert a check for range symbol - """ - check_call = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - "cf_symbol_check", lineno=func.lineno, col_offset=func.col_offset - ), - args=[deepcopy(func)], - keywords=[], - ), - func, - ) - return ast.Expr(check_call) - - def visit_For(self, node): - # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. - range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter) - if range_kind == "range_constexpr" or range_kind == None: - self.generic_visit(node) - if range_kind == "range_constexpr": - check_call = self._insert_cf_symbol_check(node.iter.func) - # Rewrite range_constexpr to range - node.iter.func = ast.Name(id="range", ctx=ast.Load()) - self._insert_range_value_check(node) - return [check_call, node] - return node - - active_symbols = self.scope_manager.get_active_symbols() - - with self.scope_manager: - if isinstance(node.target, ast.Name): - self.scope_manager.add_to_scope(node.target.id) - - if range_kind == "range_dynamic": - # Generate a warning - self.issue_deprecation_warning( - message="range_dynamic is deprecated and will be removed in the future, please remove it.", - category=DeprecationWarning, - filename=self.file_name, - lineno=node.iter.lineno, - ) - - warning_call = None - if range_kind == "range" and is_builtin_range and not has_keyword: - # Warn about possible performance regression due to behavior change - warning_call = ast.Expr( - ast.Call( - func=self._create_module_attribute( - "range_perf_warning", - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[ - ast.Constant(value=self.file_name), - ast.Constant(value=node.iter.lineno), - ] - + node.iter.args, - keywords=[], - ) - ) - ast.copy_location(warning_call, node.iter) - - is_prefixed_range = range_kind == "range" and not is_builtin_range - check_call = None - if range_kind == "range_dynamic" or is_prefixed_range: - # Insert a check for range symbol - if not is_prefixed_range: - check_call = self._insert_cf_symbol_check(node.iter.func) - else: - # Get toplevel module - check_call = self._insert_cf_symbol_check(node.iter.func.value) - - new_for_node = self.transform_for_loop(node, active_symbols) - if check_call is not None: - new_for_node = [check_call] + new_for_node - - return new_for_node if warning_call is None else [warning_call] + new_for_node - - @staticmethod - def _hoist_expr_to_assignments(expr, name): - return ast.copy_location( - ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr - ) - - def _build_select_and_assign(self, *, name, test, body, orelse, location): - node = ast.copy_location( - ast.Assign( - targets=[ast.Name(id=name, ctx=ast.Store())], - value=ast.IfExp( - test=test, - body=body, - orelse=orelse, - ), - ), - location, - ) - self.generic_visit(node) - return node - - def _handle_negative_step(self, node, start_expr, stop_expr, step_expr): - # hoist start, stop, step to assignments - start_ori_name = f"start_ori_{self.counter}" - start = self._hoist_expr_to_assignments(start_expr, start_ori_name) - stop_ori_name = f"stop_ori_{self.counter}" - stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name) - step_ori_name = f"step_ori_{self.counter}" - step = self._hoist_expr_to_assignments(step_expr, step_ori_name) - - extra_exprs = [start, stop, step] - - # Handle possible negative step, generates the following code in Python: - # isNegative = step < 0 - isNegative_name = f"isNegative_{self.counter}" - isNegative = ast.copy_location( - ast.Assign( - targets=[ast.Name(id=isNegative_name, ctx=ast.Store())], - value=ast.Compare( - left=ast.Name(id=step_ori_name, ctx=ast.Load()), - ops=[ast.Lt()], - comparators=[ast.Constant(value=0)], - ), - ), - step, - ) - - # start = stop if isNegative else start - start_name = f"start_{self.counter}" - start = self._build_select_and_assign( - name=start_name, - test=ast.Name(id=isNegative_name, ctx=ast.Load()), - body=ast.Name(id=stop_ori_name, ctx=ast.Load()), - orelse=ast.Name(id=start_ori_name, ctx=ast.Load()), - location=start, - ) - - # stop = start if isNegative else stop - stop_name = f"stop_{self.counter}" - stop = self._build_select_and_assign( - name=stop_name, - test=ast.Name(id=isNegative_name, ctx=ast.Load()), - body=ast.Name(id=start_ori_name, ctx=ast.Load()), - orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()), - location=stop, - ) - - # step = -step if isNegative else step - step_name = f"step_{self.counter}" - step = self._build_select_and_assign( - name=step_name, - test=ast.Name(id=isNegative_name, ctx=ast.Load()), - body=ast.UnaryOp( - op=ast.USub(), operand=ast.Name(id=step_ori_name, ctx=ast.Load()) - ), - orelse=ast.Name(id=step_ori_name, ctx=ast.Load()), - location=step, - ) - - # offset = start + stop if isNegative else 0 - offset_name = f"offset_{self.counter}" - offset = self._build_select_and_assign( - name=offset_name, - test=ast.Name(id=isNegative_name, ctx=ast.Load()), - body=ast.BinOp( - op=ast.Add(), - left=ast.Name(id=start_name, ctx=ast.Load()), - right=ast.Name(id=stop_name, ctx=ast.Load()), - ), - orelse=ast.Constant(value=0), - location=node, - ) - - extra_exprs.append(isNegative) - extra_exprs.append(start) - extra_exprs.append(stop) - extra_exprs.append(step) - extra_exprs.append(offset) - - # Add this to begining of loop body - # for i in range(start, stop, step): - # i = offset - i if isNegative else i - assert isinstance(node.target, ast.Name) - - target_name = node.target.id - target = self._build_select_and_assign( - name=target_name, - test=ast.Name(id=isNegative_name, ctx=ast.Load()), - body=ast.BinOp( - op=ast.Sub(), - left=ast.Name(id=offset_name, ctx=ast.Load()), - right=ast.Name(id=target_name, ctx=ast.Load()), - ), - orelse=ast.Name(id=target_name, ctx=ast.Load()), - location=node.target, - ) - - node.body.insert(0, target) - - return ( - ast.Name(id=start_name, ctx=ast.Load()), - ast.Name(id=stop_name, ctx=ast.Load()), - ast.Name(id=step_name, ctx=ast.Load()), - extra_exprs, - ) - - def transform_for_loop(self, node, active_symbols): - # Check for early exit and raise exception - self.check_early_exit(node, "for") - if node.orelse: - raise DSLAstPreprocessorError( - "dynamic for loop with else is not supported", - filename=self.file_name, - snippet=ast.unparse(node), - ) - - # Get loop target variable name - target_var_name = None - target_var_is_active_before_loop = False - if isinstance(node.target, ast.Name): - target_var_name = node.target.id - for active_symbol in active_symbols: - if target_var_name in active_symbol: - target_var_is_active_before_loop = True - active_symbols.remove(active_symbol) - break - - # Add necessary exprs to handle this - if target_var_is_active_before_loop: - # Initialize an extra loop carried variable - loop_carried_var_name = f"loop_carried_var_{self.counter}" - pre_loop_expr = ast.copy_location( - ast.Assign( - targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())], - value=ast.Name(id=target_var_name, ctx=ast.Load()), - ), - node, - ) - # append an extra assignment to the loop carried variable - node.body.append( - ast.copy_location( - ast.Assign( - targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())], - value=ast.Name(id=target_var_name, ctx=ast.Load()), - ), - node, - ) - ) - active_symbols.append({loop_carried_var_name}) - - start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter) - unroll, unroll_full = self.extract_unroll_args(node.iter) - prefetch_stages = self.extract_prefetch_stages_args(node.iter) - write_args, full_write_args_count = self.analyze_region_variables( - node, active_symbols - ) - - if has_step and self.client_module_name[0] == "cutlass": - start, stop, step, exprs = self._handle_negative_step( - node, start_expr, stop_expr, step_expr - ) - else: - start, stop, step, exprs = start_expr, stop_expr, step_expr, [] - - if target_var_is_active_before_loop: - exprs.append(pre_loop_expr) - - func_name = f"loop_body_{self.counter}" - self.counter += 1 - - func_def = self.create_loop_function( - func_name, - node, - start, - stop, - step, - unroll, - unroll_full, - prefetch_stages, - write_args, - full_write_args_count, - ) - - assign = self.create_cf_call(func_name, write_args, node) - - # This should work fine as it modifies the AST structure - exprs = exprs + [func_def] + assign - - if target_var_is_active_before_loop: - # Create a new assignment to the target variable - exprs.append( - ast.copy_location( - ast.Assign( - targets=[ast.Name(id=target_var_name, ctx=ast.Store())], - value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()), - ), - node, - ) - ) - - return exprs - - def visit_Assert(self, node): - test = self.visit(node.test) - - args = [ast.keyword(arg="test", value=test)] - if node.msg: - msg = self.visit(node.msg) - args.append(ast.keyword(arg="msg", value=msg)) - - # Rewrite to assert_executor(test, msg) - new_node = ast.Expr( - ast.Call( - func=self._create_module_attribute( - self.ASSERT_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset - ), - args=[], - keywords=args, - ) - ) - - # Propagate line number from original node to new node - ast.copy_location(new_node, node) - return new_node - - def visit_Call(self, node): - func = node.func - # Visit args and kwargs - node.args = [self.visit(arg) for arg in node.args] - node.keywords = [self.visit(kwarg) for kwarg in node.keywords] - - # Rewrite call to some built-in functions - if isinstance(func, ast.Name): - # Check if the function is 'bool' - if func.id == "bool": - return ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.BOOL_CAST, - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[node.args[0]], - keywords=[], - ), - node, - ) - elif func.id in ["any", "all"]: - helper_func = ( - self.ANY_EXECUTOR if func.id == "any" else self.ALL_EXECUTOR - ) - return ast.copy_location( - ast.Call( - func=self._create_module_attribute( - helper_func, lineno=node.lineno, col_offset=node.col_offset - ), - args=[node.args[0]], - keywords=[], - ), - node, - ) - elif func.id in ["min", "max"]: - return ast.copy_location( - ast.Call( - func=self._create_module_attribute( - func.id, - top_module_name="cutlass", - submodule_name=None, - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[node.args[0], node.args[1]], - keywords=[], - ), - node, - ) - elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): - def create_downcast_call(arg): - return ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.IMPLICIT_DOWNCAST_NUMERIC_TYPE, - submodule_name="typing", - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[arg], - keywords=[], - ), - arg, - ) - module = self.function_globals.get(func.value.id) - if isinstance(module, ModuleType) and module.__package__.endswith( - "._mlir.dialects" - ): - # Check if argument is Numeric, if so, call ir_value() - args = [] - for arg in node.args: - args.append(create_downcast_call(arg)) - kwargs = [] - for kwarg in node.keywords: - kwargs.append( - ast.copy_location( - ast.keyword( - arg=kwarg.arg, - value=create_downcast_call(kwarg.value), - ), - kwarg, - ) - ) - return ast.copy_location( - ast.Call(func=func, args=args, keywords=kwargs), node - ) - else: - node.func = self.visit(node.func) - - return node - - def visit_ClassDef(self, node): - self.class_name = node.name - self.generic_visit(node) - self.class_name = None - return node - - def _visit_target(self, target): - if isinstance(target, ast.Name): - self.scope_manager.add_to_scope(target.id) - elif isinstance(target, ast.Tuple): - for t in target.elts: - if isinstance(t, ast.Name): - self.scope_manager.add_to_scope(t.id) - - def visit_Assign(self, node): - for target in node.targets: - self._visit_target(target) - self.generic_visit(node) - return node - - def visit_AugAssign(self, node): - self._visit_target(node.target) - self.generic_visit(node) - return node - - def visit_Name(self, node): - isLoad = isinstance(node.ctx, ast.Load) - if node.id in ["max", "min", "any", "all"] and isLoad: - return ast.copy_location( - ast.Call( - func=self._create_module_attribute( - "redirect_builtin_function", - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[node], - keywords=[], - ), - node, - ) - elif node.id == "_" and isLoad: - raise DSLAstPreprocessorError("Read '_' is not allowed") - else: - self.generic_visit(node) - return node - - def check_decorator(self, node: ast.AST) -> bool: - """ - Check if the function has the correct decorator for preprocessing. - """ - if not isinstance(node, ast.FunctionDef): - return False - decorator_list = node.decorator_list - if len(decorator_list) == 0: - return False - - for d in decorator_list: - if isinstance(d, ast.Call): - if isinstance(d.func, ast.Attribute): - if d.func.attr in ["jit", "kernel"]: - if d.keywords == []: - return True - for keyword in d.keywords: - if keyword.arg == "preprocess": - try: - if isinstance(keyword.value, ast.Constant): - return keyword.value.value - else: - return ast.literal_eval(keyword.value) - except: - pass - - elif isinstance(d, ast.Attribute): - if d.attr in ["jit", "kernel"]: - return True - - return False - - def remove_dsl_decorator(self, decorator_list): - """ - Remove .jit and .kernel decorators - The decorator can be in two forms: - - @jit(...) - - @jit - """ - new_decorator_list = [] - decorator_names = ["jit", "kernel"] - for d in decorator_list: - is_jit_or_kernel = False - if isinstance(d, ast.Call): - if isinstance(d.func, ast.Attribute): - if d.func.attr in decorator_names: - is_jit_or_kernel = True - elif isinstance(d, ast.Attribute): - if d.attr in decorator_names: - is_jit_or_kernel = True - - if not is_jit_or_kernel: - new_decorator_list.append(d) - return new_decorator_list - - def visit_FunctionDef(self, node): - with self.scope_manager: - self.function_counter += 1 - self.function_name = node.name - if self.function_depth > 0: - self.local_closures.add(node.name) - - self.function_depth += 1 - - # Add function name and arguments - self.scope_manager.add_to_scope(node.name) - for arg in node.args.args: - self.scope_manager.add_to_scope(arg.arg) - - self.generic_visit(node) - - self.function_depth -= 1 - - # Remove .jit and .kernel decorators - node.decorator_list = self.remove_dsl_decorator(node.decorator_list) - return node - - def visit_With(self, node): - with self.scope_manager: - for item in node.items: - if isinstance(item.optional_vars, ast.Name): - self.scope_manager.add_to_scope(item.optional_vars.id) - self.generic_visit(node) - - return node - - def visit_While(self, node): - # Constexpr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - check = self._insert_cf_symbol_check(node.test.func) - return [check, node] - - active_symbols = self.scope_manager.get_active_symbols() - - with self.scope_manager: - # Check for early exit and raise exception - self.check_early_exit(node, "while") - - write_args, full_write_args_count = self.analyze_region_variables( - node, active_symbols - ) - func_name = f"while_region_{self.counter}" - self.counter += 1 - - func_def = self.create_while_function( - func_name, node, write_args, full_write_args_count - ) - assign = self.create_cf_call(func_name, write_args, node) - - return [func_def] + assign - - def visit_Try(self, node): - with self.scope_manager: - self.generic_visit(node) - return node - - def visit_ExceptHandler(self, node): - with self.scope_manager: - if node.name: # Exception variable - self.scope_manager.add_to_scope(node.name) - self.generic_visit(node) - return node - - def create_cf_call(self, func_name, yield_args, node): - """Creates the assignment statement for the if function call""" - if not yield_args: - return [ - ast.copy_location( - ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node - ) - ] - has_self = False - for i, arg in enumerate(yield_args): - if arg == "self": - has_self = True - yield_args[i] = "yield_self" - break - if len(yield_args) == 1: - assign = ast.Assign( - targets=[ast.Name(id=yield_args[0], ctx=ast.Store())], - value=ast.Name(id=func_name, ctx=ast.Load()), - ) - else: - assign = ast.Assign( - targets=[ - ast.Tuple( - elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args], - ctx=ast.Store(), - ) - ], - value=ast.Name(id=func_name, ctx=ast.Load()), - ) - - if has_self: - fix_self = ast.Expr( - value=ast.Call( - func=self._create_module_attribute( - "copy_members", lineno=node.lineno, col_offset=node.col_offset - ), - args=[ - ast.Name(id="self", ctx=ast.Load()), - ast.Name(id="yield_self", ctx=ast.Load()), - ], - keywords=[], - ) - ) - return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)] - else: - return [ast.copy_location(assign, node)] - - def visit_IfExp(self, node): - """ - Visits an inline if-else expression (ternary operator). - This is the Python equivalent of `x if condition else y`. - """ - self.generic_visit(node) - # Emit - # node if type(pred) == bool else select_(pred, body, orelse) - # so if pred is a python bool, use python to short-circuit and avoid emit arith.select - self.import_top_module = True - return ast.copy_location( - ast.IfExp( - test=ast.Compare( - left=ast.Call( - func=ast.Name(id="type", ctx=ast.Load()), - args=[node.test], - keywords=[], - ), - ops=[ast.Eq()], - comparators=[ast.Name(id="bool", ctx=ast.Load())], - ), - body=node, # Original ternary expression - orelse=ast.Call( - func=self._create_module_attribute( - "select_", top_module_name="cutlass", submodule_name=None - ), - args=[ - node.test, - node.body, - node.orelse, - ], - keywords=[], - ), - ), - node, - ) - - cmpops = { - "Eq": "==", - "NotEq": "!=", - "Lt": "<", - "LtE": "<=", - "Gt": ">", - "GtE": ">=", - "Is": "is", - "IsNot": "is not", - "In": "in", - "NotIn": "not in", - } - def compare_ops_to_str(self, node): - names = [ - ast.Constant(value=self.cmpops[op.__class__.__name__]) for op in node.ops - ] - return ast.List(elts=names, ctx=ast.Load()) - - def visit_Compare(self, node): - self.generic_visit(node) - - comparator_strs = self.compare_ops_to_str(node) - - keywords = [ - ast.keyword(arg="left", value=node.left), - ast.keyword( - arg="comparators", value=ast.List(elts=node.comparators, ctx=ast.Load()) - ), - ast.keyword(arg="ops", value=comparator_strs), - ] - - call = ast.copy_location( - ast.Call( - func=self._create_module_attribute(self.COMPARE_EXECUTOR), - args=[], - keywords=keywords, - ), - node, - ) - - return call - - def visit_If(self, node): - # const_expr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - check = self._insert_cf_symbol_check(node.test.func) - return [check, node] - - active_symbols = self.scope_manager.get_active_symbols() - with self.scope_manager: - # Check for early exit and raise exception - self.check_early_exit(node, "if") - - yield_args, full_write_args_count = self.analyze_region_variables( - node, active_symbols - ) - func_name = f"if_region_{self.counter}" - self.counter += 1 - - func_def = self.create_if_function( - func_name, node, yield_args, full_write_args_count - ) - assign = self.create_cf_call(func_name, yield_args, node) - - return [func_def] + assign - - def generate_get_locals_or_none_call(self, write_args): - return ast.Call( - func=self._create_module_attribute("get_locals_or_none"), - args=[ - ast.Call( - func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[] - ), - ast.List( - elts=[ast.Constant(value=arg) for arg in write_args], - ctx=ast.Load(), - ), - ], - keywords=[], - ) - - def create_if_function(self, func_name, node, write_args, full_write_args_count): - test_expr = self.visit(node.test) - pred_name = self.make_func_param_name("pred", write_args) - func_args = [ast.arg(arg=pred_name, annotation=None)] - func_args += [ast.arg(arg=var, annotation=None) for var in write_args] - func_args_then_else = [ast.arg(arg=var, annotation=None) for var in write_args] - - then_body = [] - for stmt in node.body: - transformed_stmt = self.visit(stmt) # Recursively visit inner statements - if isinstance(transformed_stmt, list): - then_body.extend(transformed_stmt) - else: - then_body.append(transformed_stmt) - - # Create common return list for all blocks - return_list = ast.List( - elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], - ctx=ast.Load(), - ) - - # Create common function arguments - func_decorator_arguments = ast.arguments( - posonlyargs=[], args=func_args, kwonlyargs=[], kw_defaults=[], defaults=[] - ) - func_then_else_arguments = ast.arguments( - posonlyargs=[], - args=func_args_then_else, - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ) - - then_block_name = f"then_block_{self.counter}" - else_block_name = f"else_block_{self.counter}" - elif_region_name = f"elif_region_{self.counter}" - self.counter += 1 - - # Create then block - then_block = ast.copy_location( - ast.FunctionDef( - name=then_block_name, - args=func_then_else_arguments, - body=then_body + [ast.Return(value=return_list)], - decorator_list=[], - ), - node, - ) - - # Decorator keywords - decorator_keywords = [ - ast.keyword( - arg="pred", value=test_expr - ), # ast.Name(id="pred", ctx=ast.Load()) - ast.keyword( - arg="write_args", - value=self.generate_get_locals_or_none_call(write_args), - ), - ] - - # Create decorator - decorator = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.DECORATOR_IF_STATEMENT, - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[], - keywords=decorator_keywords, - ), - node, - ) - - # Executor keywords - execute_keywords = [ - ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), - ast.keyword( - arg="write_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], - ctx=ast.Load(), - ), - ), - ast.keyword( - arg="full_write_args_count", - value=ast.Constant(value=full_write_args_count), - ), - ast.keyword( - arg="write_args_names", - value=ast.List( - elts=[ast.Constant(value=arg) for arg in write_args], - ctx=ast.Load(), - ), - ), - ast.keyword( - arg="then_block", value=ast.Name(id=then_block_name, ctx=ast.Load()) - ), - ] - - # Handle different cases - if not write_args and node.orelse == []: - # No write_args case - only then_block needed - execute_call = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset - ), - args=[], - keywords=execute_keywords, - ), - node, - ) - func_body = [then_block, ast.Return(value=execute_call)] - else: - # Create else block based on node.orelse - if node.orelse: - if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If): - # Handle elif case - elif_node = node.orelse[0] - nested_if_name = elif_region_name - # Recursion for nested elif - nested_if = self.create_if_function( - nested_if_name, elif_node, write_args, full_write_args_count - ) - else_block = ast.FunctionDef( - name=else_block_name, - args=func_then_else_arguments, - body=[ - nested_if, - ast.Return( - value=ast.Name(id=nested_if_name, ctx=ast.Load()) - ), - ], - decorator_list=[], - ) - else: - - else_body = [] - for stmt in node.orelse: - transformed_stmt = self.visit( - stmt - ) # Recursively visit inner statements - if isinstance(transformed_stmt, list): - else_body.extend(transformed_stmt) - else: - else_body.append(transformed_stmt) - - # Regular else block - else_block = ast.FunctionDef( - name=else_block_name, - args=func_then_else_arguments, - body=else_body + [ast.Return(value=return_list)], - decorator_list=[], - ) - else: - # Default else block - else_block = ast.FunctionDef( - name=else_block_name, - args=func_then_else_arguments, - body=[ast.Return(value=return_list)], - decorator_list=[], - ) - - # Add else_block to execute keywords - execute_keywords.append( - ast.keyword( - arg="else_block", value=ast.Name(id=else_block_name, ctx=ast.Load()) - ) - ) - - execute_call = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset - ), - args=[], - keywords=execute_keywords, - ), - node, - ) - func_body = [ - then_block, - ast.copy_location(else_block, node), - ast.Return(value=execute_call), - ] - - return ast.copy_location( - ast.FunctionDef( - name=func_name, - args=func_decorator_arguments, - body=func_body, - decorator_list=[decorator], - ), - node, - ) - - def create_while_function(self, func_name, node, write_args, full_write_args_count): - """Create a while function that looks like: - - @while_selector(pred, write_args=[]) - def while_region(pred, write_args): - def while_before_block(*write_args): - # Note that during eval of pred can possibly alter yield_args - return *pred, write_args - def while_after_block(*write_args): - ...loop_body_transformed... - return write_args - return self.while_executor(pred, write_args, - while_before_block, while_after_block, constexpr) - write_args = while_region(pred, write_args) - - Which will later be executed as psuedo-code: - - # Dynamic mode: - scf.WhileOp(types(write_args), write_args) - with InsertionPoint(before_block): - cond, write_args = while_before_block(*write_args) - scf.ConditionOp(cond, write_args) - with InsertionPoint(after_block): - write_args = while_after_block(write_args) - scf.YieldOp(write_args) - return while_op.results_ - - # Const mode: - cond, write_args = while_before_block(write_args) - while pred: - write_args = body_block(write_args) - cond, write_args = while_before_block(write_args) - return write_args - """ - test_expr = self.visit(node.test) - pred_name = self.make_func_param_name("pred", write_args) - - # Section: decorator construction - decorator_keywords = [ - ast.keyword(arg="pred", value=test_expr), - ast.keyword( - arg="write_args", - value=self.generate_get_locals_or_none_call(write_args), - ), - ] - decorator = ast.copy_location( - ast.Call( - func=self._create_module_attribute( - self.DECORATOR_WHILE_STATEMENT, - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[], - keywords=decorator_keywords, - ), - node, - ) - - # Section: Shared initialization for before and after blocks - while_before_block_name = f"while_before_block_{self.counter}" - while_after_block_name = f"while_after_block_{self.counter}" - self.counter += 1 - block_args_args = [ast.arg(arg=var, annotation=None) for var in write_args] - block_args = ast.arguments( - posonlyargs=[], - args=block_args_args, - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ) - - yield_args_ast_name_list = ast.List( - elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], - ctx=ast.Load(), - ) - - # Section: while_before_block FunctionDef, which contains condition - while_before_return_list = ast.List( - elts=[test_expr, yield_args_ast_name_list], - ctx=ast.Load(), - ) - while_before_stmts = [ast.Return(value=while_before_return_list)] - while_before_block = ast.copy_location( - ast.FunctionDef( - name=while_before_block_name, - args=block_args, - body=while_before_stmts, - decorator_list=[], - ), - test_expr, - ) - - # Section: while_after_block FunctionDef, which contains loop body - while_after_stmts = [] - for stmt in node.body: - transformed_stmt = self.visit(stmt) # Recursively visit inner statements - if isinstance(transformed_stmt, list): - while_after_stmts.extend(transformed_stmt) - else: - while_after_stmts.append(transformed_stmt) - while_after_stmts.append(ast.Return(value=yield_args_ast_name_list)) - - while_after_block = ast.copy_location( - ast.FunctionDef( - name=while_after_block_name, - args=block_args, - body=while_after_stmts, - decorator_list=[], - ), - node, - ) - - # Section: Execute via executor - execute_keywords = [ - ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), - ast.keyword( - arg="write_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], - ctx=ast.Load(), - ), - ), - ast.keyword( - arg="full_write_args_count", - value=ast.Constant(value=full_write_args_count), - ), - ast.keyword( - arg="while_before_block", - value=ast.Name(id=while_before_block_name, ctx=ast.Load()), - ), - ast.keyword( - arg="while_after_block", - value=ast.Name(id=while_after_block_name, ctx=ast.Load()), - ), - ast.keyword( - arg="write_args_names", - value=ast.List( - elts=[ast.Constant(value=arg) for arg in write_args], - ctx=ast.Load(), - ), - ), - ] - - execute_call = ast.Call( - func=self._create_module_attribute( - self.WHILE_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset - ), - args=[], - keywords=execute_keywords, - ) - - # Putting everything together, FunctionDef for while_region - func_args_args = [ast.arg(arg=pred_name, annotation=None)] - func_args_args += [ast.arg(arg=var, annotation=None) for var in write_args] - func_args = ast.arguments( - posonlyargs=[], - args=func_args_args, - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ) - - return ast.copy_location( - ast.FunctionDef( - name=func_name, - args=func_args, - body=[ - while_before_block, - while_after_block, - ast.Return(value=execute_call), - ], - decorator_list=[decorator], - ), - node, - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py deleted file mode 100644 index 5d9234f2fe760ba0026a63c139b8535dd777f621..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py +++ /dev/null @@ -1,153 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides jit cache load/dump helper functions -""" - -import os -import uuid -import random -import tempfile -import pwd -import time -from pathlib import Path -import hashlib - -from .utils.logger import log -from .jit_executor import JitExecutor - -from .._mlir import ir - -# ============================================================================= -# Jit Cache Helper functions -# ============================================================================= - - -def get_current_user(): - # Try to get the user from the environment variable first - user = os.getenv("USER") or os.getenv("USERNAME") - if not user: - # Fallback for Unix-like systems - user = pwd.getpwuid(os.getuid()).pw_name - return user - - -try: - default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/" -except Exception as e: - # If all else fails, provide a default fallback path - default_generated_ir_path = "/tmp/cutlass_python_cache/" - print(f"Could not determine user, using default path. Error: {e}") - - -def load_ir(file, asBytecode=False): - """Load generated IR from a file.""" - assert "mlir" in file - func_name = file.split(".mlir")[0].split("dsl_")[-1] - with ir.Context() as ctx: - with open(file, "rb" if asBytecode else "r") as f: - module = ir.Module.parse(f.read()) - - return func_name, module - - -def make_unique_filename(fpath: Path, new_ext: str = None) -> Path: - """Generate a unique filename with an optional new extension.""" - random_part = random.randint(0, 999999) - timestamp = time.time() - hash_input = f"{fpath}_{timestamp}_{random_part}".encode() - hash_code = hashlib.md5(hash_input).hexdigest()[:16] # Shorter hash for readability - stem_with_hash = f"{fpath.stem}_{hash_code}" - return fpath.with_name(stem_with_hash).with_suffix(new_ext or fpath.suffix) - - -def save_ir( - dsl_name: str, - module: object, - fname: str, - isTemp: bool = False, - asBytecode: bool = False, -) -> str: - """Save generated IR to a file.""" - initial_name = f"{dsl_name.lower()}_{fname}.mlir" - save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd()) - save_fname = save_path / initial_name - # Random ID to avoid any collisions - rnd_id = str(uuid.uuid4()) - pid = os.getpid() - # use temp dir to be robust against program interruptions - temp_dir = os.path.join(save_path, f"tmp.pid_{pid}_{rnd_id}") - # If the process exits abnormally, may leave a temporary folder. Needs to be removed manually. - os.makedirs(temp_dir, exist_ok=False) - temp_fname = os.path.join(temp_dir, initial_name) - - if asBytecode: - with open(temp_fname, "wb") as f: - module.operation.write_bytecode(f) - else: - with open(temp_fname, "w") as f: - print(module, file=f) - # os.replace is guaranteed to be atomic on POSIX systems if it succeeds - # so filepath cannot see a partial write - os.replace(temp_fname, save_fname) - os.removedirs(temp_dir) - log().debug("Generated IR saved into %s", save_fname) - return save_fname - - -def check_func_name(jit_cache, func_name): - if not func_name in jit_cache: - jit_cache[func_name] = JitExecutor(None, None, None, None, None, None) - return jit_cache - - -def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path): - """Load cache from a directory path.""" - if not os.path.exists(path): - return dict() - files = os.listdir(path) - jit_cache = dict() - try: - for idx, file in enumerate(files): - if idx >= int(cache_limit): - break - # identify dsl prefix - if not file.startswith(f"{dsl_name.lower()}"): - continue - if ".mlir" in file: - func_name, ir_module = load_ir( - os.path.join(path, file), asBytecode=True - ) - jit_cache = check_func_name(jit_cache, func_name) - jit_cache[func_name].ir_module = ir_module - except Exception as e: - print(f"{dsl_name} failed with loading generated IR cache.", e) - jit_cache = dict() - return jit_cache - - -def dump_cache_to_path( - dsl_name, jit_cache, cache_limit, path=default_generated_ir_path -): - log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache)) - os.makedirs(path, exist_ok=True) - original_path = os.getcwd() - try: - os.chdir(path) - for idx, [key, value] in enumerate(jit_cache.items()): - if idx >= int(cache_limit): - break - save_ir(dsl_name, value.ir_module, key, asBytecode=True) - except Exception as e: - print(f"{dsl_name} failed with caching generated IR", e) - finally: - os.chdir(original_path) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py deleted file mode 100644 index 3cf413ed5018f99ae748cb2eb1883992f27a87b9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py +++ /dev/null @@ -1,268 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import os -from typing import Any, Dict, Iterable, Optional, Union - -""" -This module provides a Exception classes DSL class for any Dialect. -""" - - -# Add color codes at the top of the file after imports -class Colors: - """ANSI color codes for error messages""" - - RED = "\033[91m" - YELLOW = "\033[93m" - BLUE = "\033[94m" - GREEN = "\033[92m" - BOLD = "\033[1m" - RESET = "\033[0m" - - -# ============================================================================= -# DSL Exceptions -# ============================================================================= - - -class DSLBaseError(Exception): - """ - Base exception for DSL-related errors. - Provides optional contextual metadata to aid in debugging. - """ - - def __init__( - self, - message: str, - line: Optional[int] = None, - snippet: Optional[str] = None, - filename: Optional[str] = None, - error_code: Optional[Union[str, int]] = None, - context: Optional[Union[Dict[str, Any], str]] = None, - suggestion: Optional[str] = None, - cause: Optional[BaseException] = None, - ) -> None: - self.message = message - self.line = line - self.filename = filename - self.snippet = snippet - self.error_code = error_code - self.context = context - self.suggestion = suggestion - self.cause = cause - - super().__init__(self._format_message()) - - def _format_message(self): - """ - Formats the complete error message with available metadata. - Override this in subclasses if you want to change formatting logic. - """ - parts = [f"{self.__class__.__name__}: {self.message}"] - - if self.error_code is not None: - parts.append(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n") - - if self.line is not None: - parts.append(f" Line: {self.line}") - - if self.filename is not None: - parts.append(f" File: {self.filename}") - - if self.snippet: - # Optionally truncate long snippets for readability - parts.append(f" Snippet: \n {self.snippet}") - - if self.cause: - parts.append(f" Caused exception: {self.cause}") - - if self.context: - if isinstance(self.context, dict): - parts.append(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n") - for key, value in self.context.items(): - parts.append(f" {key}: {value}") - else: - parts.append( - f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {self.context}" - ) - - if self.suggestion: - parts.append(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}") - if isinstance(self.suggestion, (list, tuple)): - for suggestion in self.suggestion: - parts.append(f" {Colors.GREEN}{suggestion}{Colors.RESET}") - else: - parts.append(f" {self.suggestion}") - - return "\n".join(parts) - - -class DSLRuntimeError(DSLBaseError): - """ - Raised when an error occurs during JIT-time code generation in the DSL. - """ - - # Inherits all logic from DSLBaseError; override methods if you need - # specialized behavior or formatting for runtime errors. - pass - - -def _get_friendly_cuda_error_message(error_code, error_name): - # Avoid circular dependency - from .runtime.cuda import get_device_info - - """Get a user-friendly error message for common CUDA errors.""" - # Strip the byte string markers if present - if isinstance(error_name, bytes): - error_name = error_name.decode("utf-8") - elif ( - isinstance(error_name, str) - and error_name.startswith("b'") - and error_name.endswith("'") - ): - error_name = error_name[2:-1] - - # Add target architecture info - target_arch = os.getenv("CUTE_DSL_ARCH", "unknown") - - error_messages = { - "CUDA_ERROR_INVALID_SOURCE": ( - f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n" - ), - "CUDA_ERROR_NO_BINARY_FOR_GPU": ( - f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n" - ), - "CUDA_ERROR_OUT_OF_MEMORY": ( - f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n" - ), - "CUDA_ERROR_INVALID_DEVICE": ( - f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n" - ), - "CUDA_ERROR_NOT_INITIALIZED": ( - f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n" - ), - "CUDA_ERROR_INVALID_VALUE": ( - f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n" - f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}" - ), - } - - error_suggestions = { - "CUDA_ERROR_INVALID_SOURCE": ( - f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture", - f"2. Clear the compilation cache and regenerate the kernel", - f"3. Check CUDA toolkit installation", - ), - "CUDA_ERROR_NO_BINARY_FOR_GPU": ( - f"Set env CUTE_DSL_ARCH to match your GPU architecture", - ), - "CUDA_ERROR_OUT_OF_MEMORY": ( - f"1. Reduce batch size", - f"2. Reduce model size", - f"3. Free unused GPU memory", - ), - "CUDA_ERROR_INVALID_DEVICE": ( - f"1. Check if CUDA device is properly initialized", - f"2. Verify GPU is detected: nvidia-smi", - f"3. Check CUDA_VISIBLE_DEVICES environment variable", - ), - "CUDA_ERROR_NOT_INITIALIZED": ( - f"1. Check CUDA driver installation", - f"2. call `cuda.cuInit(0)` before any other CUDA operation", - f"3. Run nvidia-smi to confirm GPU status", - ), - "CUDA_ERROR_INVALID_VALUE": ( - f"1. Your GPU model", - f"2. SM ARCH setting", - f"3. Steps to reproduce", - ), - } - - message = error_messages.get( - error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}" - ) - - # Add debug information - debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n" - debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n" - debug_info += ( - f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n" - ) - - try: - # Get GPU information using CUDA Python API - debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n" - gpu_info = get_device_info() - debug_info += gpu_info.pretty_str() - - if target_arch and gpu_info.compatible_archs: - debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n" - - if target_arch not in gpu_info.compatible_archs: - debug_info += ( - f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n" - f"💡 Please use one of SM ARCHs: " - f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n" - ) - elif target_arch != gpu_info.sm_arch: - debug_info += ( - f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n" - f"• Current: {target_arch}\n" - f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n" - ) - else: - debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n" - - except Exception as e: - debug_info += ( - f"\n{Colors.YELLOW}ℹ️ Could not retrieve GPU info: {str(e)}{Colors.RESET}" - ) - - return message, debug_info, error_suggestions.get(error_name, "") - - -class DSLCudaRuntimeError(DSLBaseError): - """ - Raised when an error occurs during CUDA runtime code generation in the DSL. - """ - - # Inherits all logic from DSLRuntimeError; override methods if you need - # specialized behavior or formatting for runtime errors. - def __init__(self, error_code, error_name) -> None: - self._error_code = error_code - self._error_name = error_name - message, debug_info, suggestion = _get_friendly_cuda_error_message( - error_code, error_name - ) - - super().__init__( - message, error_code=error_code, context=debug_info, suggestion=suggestion - ) - - -class DSLAstPreprocessorError(DSLBaseError): - """ - Raised when an error occurs during AST preprocessing or visiting in the DSL. - """ - - # Same approach: You could override _format_message if you want - # to emphasize AST node details or anything specific to preprocessing. - pass - - -class DSLNotImplemented(DSLBaseError): - """ - Raised when a feature of the DSL is not implemented yet. - """ - - # Useful for stubs in your DSL that you plan to implement in the future. - pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py deleted file mode 100644 index f8b2da07ac9ac104f56c16a5cfcbbf01f01ee786..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py +++ /dev/null @@ -1,288 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides a class that compiles generated IR using MLIR's PassManager -and executes it using MLIR's ExecutionEngine. - -""" - -from typing import Sequence, Optional, Tuple -import os -import sys -import inspect -import argparse -from .common import DSLRuntimeError -from .utils.logger import log - -_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(_SCRIPT_PATH) - -from .._mlir import ir - - -# ============================================================================= -# Compiler Class -# ============================================================================= - - -class CompilationError(RuntimeError): - """Custom error class for compilation failures""" - - # Add ANSI color codes - RED = "\033[91m" - YELLOW = "\033[93m" - BLUE = "\033[94m" - GREEN = "\033[92m" - BOLD = "\033[1m" - RESET = "\033[0m" - - def __init__( - self, - message: str, - nvvm_error: Optional[str] = None, - ir_context: Optional[str] = None, - cuda_toolkit: Optional[str] = None, - arch: Optional[str] = None, - ): - self.nvvm_error = nvvm_error - self.ir_context = ir_context - self.cuda_toolkit = cuda_toolkit - self.arch = arch - # Call parent with formatted error to avoid showing class name - super().__init__("") # Empty string to avoid class name - # Store formatted error for str() representation - self._formatted_error = self._format_error() - - def __str__(self) -> str: - """Override string representation to avoid showing class name""" - return self._formatted_error - - def __repr__(self) -> str: - """Override repr representation to avoid showing class name""" - return self._formatted_error - - def _format_error(self) -> str: - if not self.nvvm_error: - return str(self.args[0]) - - return f"""NVVM Compilation Error: ----------------------- - -{self.BLUE}⚙️ Current Settings:{self.RESET} -{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"} -- Target Architecture: {self.arch}{self.RESET} - -IR Context (truncated): -{self.ir_context} - -{self.YELLOW}💡 Possible Solutions:{self.RESET} -{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly -2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit -3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}""" - - -class Compiler: - """Compiler class for compiling and building MLIR modules.""" - - def __init__(self, passmanager, execution_engine): - self.passmanager = passmanager - self.execution_engine = execution_engine - - def __call__(self, module): - """Convenience application method.""" - self.compile(module) - - def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]: - """Process error message to extract NVVM error and IR context""" - nvvm_error = None - ir_msg = "" - - if "NVVM_ERROR" in error_msg: - # Extract the specific NVVM error - nvvm_error = ( - error_msg.split("libNVVM extra log:")[1].strip() - if "libNVVM extra log:" in error_msg - else error_msg - ) - - # Extract IR context - if "see current operation:" in error_msg: - # Get the IR section - ir_section = error_msg.split("see current operation:")[1].strip() - # Remove duplicate IR section - ir_section = ir_section.split("error: unknown: Failed translating")[ - 0 - ].strip() - - # Get first few lines and last few lines of the IR - ir_lines = ir_section.split("\n") - if len(ir_lines) > 10: - ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:]) - else: - ir_msg = ir_section - - return nvvm_error, ir_msg - - def compile( - self, - module, - pipeline: str, - cuda_toolkit: str = "", - arch: str = "", - enable_verifier=False, - ): - """Compiles the module by invoking the pipeline.""" - try: - pm = self.passmanager.PassManager.parse(pipeline) - pm.enable_verifier(enable_verifier) - pm.run(module.operation) - except Exception as e: - error_msg = str(e) - nvvm_error, ir_msg = self._process_error(error_msg) - - if nvvm_error: - raise CompilationError( - error_msg, - nvvm_error=nvvm_error, - ir_context=ir_msg, - cuda_toolkit=cuda_toolkit, - arch=arch, - ) from e - raise e - - def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()): - """Wraps the module in a JIT execution engine.""" - return self.execution_engine.ExecutionEngine( - module, opt_level=opt_level, shared_libs=shared_libs - ) - - def compile_and_jit( - self, - module, - pipeline: str, - shared_libs: Sequence[str] = (), - opt_level: int = 2, - cuda_toolkit: str = "", - arch: str = "", - ): - """Compiles and jits the module.""" - self.compile( - module, - pipeline, - cuda_toolkit, - arch, - ) - return self.jit(module, opt_level, shared_libs) - - -class CompileOptions: - def __init__(self, options: str = ""): - """ - This class encapsulates all compilation options relevant to function compilation. - It provides a convenient way to manage and pass compilation options, - particularly for controlling compilation settings. - By centralizing these options, it ensures consistent and flexible configuration of - compilation parameters such as optimization level, debugging control, etc. - - :param options: The options for the function. Will be parsed by argparse. - :type options: str - """ - if not isinstance(options, str): - raise DSLRuntimeError( - f"Invalid compilation `options`: {options}, it should be a string" - ) - self._parser = argparse.ArgumentParser() - self._parser.add_argument("--opt-level", nargs="?", type=int, default=3) - self._parser.add_argument( - "--enable-device-assertions", action="store_true", default=False - ) - self._parser.add_argument("--link-libraries", type=str, default="") - - try: - self._options = self._parser.parse_args(options.split()) - except SystemExit as e: - # catch argparse error and raise as DSLRuntimeError - raise DSLRuntimeError( - f"Invalid compile options: '{options}'. Please check the option values and format." - ) - log().info("`cute.compile` CompileOptions: options=" + options) - - def to_str(self): - """ - Generate a string representation of all compilation options - which will be used in pipeline options. - """ - option_strings = [] - for key, value in vars(self._options).items(): - hyphen_key = key.replace("_", "-") - if isinstance(value, bool): - formatted_value = "true" if value else "false" - else: - formatted_value = str(value) - option_strings.append(f"{hyphen_key}={formatted_value}") - - return " ".join(option_strings) - - -def compile(func, *args, **kwargs): - """ - This function is used to compile a `cute.jit` decorated function. - It will process the compile options and input parameters, do explicit compilation and return the jit executor. - - :param func: The function to compile. It can be a regular function, a method or a class instance. - :param args: The arguments to pass to the function. - :param kwargs: The keyword arguments to pass to the function. It can contain `options` like - `opt_level` to control the compilation flags. - - :return: The jit executor. - - :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable. - """ - if func is None: - raise DSLRuntimeError("Function is not set or invalid.") - - if not callable(func): - raise DSLRuntimeError("Object is not callable.") - - kwargs["compile_only"] = True - kwargs["no_cache"] = True - - if inspect.isfunction(func): - # regular function - pass - elif inspect.ismethod(func): - # if it's a method, add the instance to the first argument - args = [func.__self__] + list(args) - func = func.__func__ - elif inspect.isclass(type(func)) and hasattr(func, "__call__"): - # If it's a class instance, get the class's __call__ method - args = [func] + list(args) - # Get the actual function from the class definition - func = func.__call__.__func__ - else: - raise DSLRuntimeError( - "Invalid function type, only function, method and module are supported, but got", - func, - ) - - # If it's a wrapped function created by jit decorator, get the original function - if hasattr(func, "__wrapped__"): - func = func.__wrapped__ - - if not hasattr(func, "_dsl_object"): - raise DSLRuntimeError("Function is not decorated with jit decorator.") - - # process compile options, extract the options and remove them from the kwargs - options = kwargs.pop("options", "") - func._dsl_object.compile_options = CompileOptions(options) - fcn_ptr = func._dsl_object._preprocess_and_execute(func) - return func._dsl_object._func(fcn_ptr, *args, **kwargs) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py deleted file mode 100644 index 2b17d22b1e6d7157a7f14334b0f29f1386c58c15..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py +++ /dev/null @@ -1,1686 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides a main DSL class for any Dialect. -The DSL should be inherited as a new class, and its initialization requires dialects. -It handles most of the mechanics for the DSL in an agnostic way, -for example, it can handle various dialect-specific tasks. -""" - - -# Standard library imports -from dataclasses import dataclass, field -import atexit -import os -import io -import sys -import errno -import ctypes -import re -import inspect -import argparse -import hashlib -from functools import lru_cache, wraps -from collections import namedtuple -from abc import ABC, abstractmethod -from typing import Any, Union, Tuple, get_origin, get_args, List -from types import FunctionType, SimpleNamespace -import warnings - -from . import typing as t -from .env_manager import EnvironmentVarManager -from .compiler import CompileOptions -from .ast_helpers import DSLOptimizationWarning - -# ============================================================================= -# CUDA Python -# ============================================================================= - -from ..base_dsl._mlir_helpers.arith import const - -# ============================================================================= -# Local module imports -# ============================================================================= - -from .cache_helpers import * -from .jit_executor import JitExecutor -from .utils.timer import timer -from .utils.logger import setup_log, log -from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe -from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry - -from .ast_preprocessor import DSLPreprocessor -from .common import * -from .typing import ( - get_c_pointers, - get_mlir_types, -) - -# ============================================================================= -# MLIR modules -# ============================================================================= - -from .._mlir import ir -from .._mlir import runtime as rt -from .._mlir.extras import types as T -from .._mlir.dialects import arith, math, func - -# ============================================================================= -# Global Variables -# ============================================================================= - -MLIR_DYNAMIC = -9223372036854775808 - -# ============================================================================= -# Codegen Utils -# ============================================================================= - - -def _numpy_type_to_mlir_type(dtype): - if dtype == np.float64: - return T.f64() - if dtype == np.float16: - return T.f16() - if dtype == np.float32: - return T.f32() - if dtype == np.int64: - return T.i64() - if dtype == np.int32: - return T.i32() - if dtype == np.int16: - return T.i16() - if dtype == np.int8: - return T.i8() - if dtype == np.uint64: - return T.ui64() - if dtype == np.uint32: - return T.ui32() - if dtype == np.uint16: - return T.ui16() - if dtype == np.uint8: - return T.ui8() - if dtype == np.bool_: - return T.bool() - if dtype == f8E5M2: - return T.f8E5M2() - if dtype == f8E4M3FN: - return T.f8E4M3FN() - if dtype == f8E8M0FNU: - return T.f8E8M0FNU() - if dtype == f6E3M2FN: - return T.f6E3M2FN() - if dtype == f6E2M3FN: - return T.f6E2M3FN() - if dtype == f4E2M1FN: - return T.f4E2M1FN() - assert False, f"Unknown type {type}" - - -def _mlir_type_to_numpy_type(type): - if type == T.f64(): - return np.float64 - if type == T.f16(): - return np.float16 - if type == T.f32(): - return np.float32 - if type == T.i64(): - return np.int64 - if type == T.i32(): - return np.int32 - if type == T.i16(): - return np.int16 - if type == T.i8(): - return np.int8 - if type == T.ui64(): - return np.uint64 - if type == T.ui32(): - return np.uint32 - if type == T.ui16(): - return np.uint16 - if type == T.ui8(): - return np.uint8 - if type == T.bool(): - return np.bool_ - assert False, f"Unknown type {type}" - - -# ============================================================================= -# Main DSL Class -# ============================================================================= - - -def is_dynamic_expression(value): - """ - Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value - """ - if isinstance(value, (tuple, list)): - for x in value: - if is_dynamic_expression(x): - return True - elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr( - value, "__extract_mlir_values__" - ): - return True - return False - - -def extract_mlir_values(obj): - """ - Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values - """ - res = [] - if hasattr(obj, "__extract_mlir_values__"): - res = obj.__extract_mlir_values__() - elif isinstance(obj, (tuple, list)): - res = sum((extract_mlir_values(x) for x in obj), []) - elif isinstance(obj, SimpleNamespace): - res = [] - for k, v in obj.__dict__.items(): - res.extend(extract_mlir_values(v)) - # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in extract_mlir_values to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - elif isinstance(obj, ir.Value): - res = [obj] - elif isinstance(obj, ir.BlockArgumentList): - res = list(obj) # type: ignore - - return res - - -def new_from_mlir_values(obj, values): - """ - Create a new python object by populating containing MLIR values with list of new values - """ - if hasattr(obj, "__new_from_mlir_values__"): - return obj.__new_from_mlir_values__(values) - elif isinstance(obj, (tuple, list)): - res = [] - for x in obj: - n_items = len(get_mlir_types(x)) - res.append(new_from_mlir_values(x, values[:n_items])) - values = values[n_items:] - obj_ty = type(obj) - return obj_ty(res) - elif isinstance(obj, SimpleNamespace): - res = SimpleNamespace() - for k, v in obj.__dict__.items(): - n_items = len(get_mlir_types(v)) - res.__dict__[k] = new_from_mlir_values(v, values[:n_items]) - values = values[n_items:] - return res - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in new_from_mlir_values to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - elif is_dynamic_expression(obj): - - if len(values) == 0: - return obj - - assert len(values) == 1 - return values[0] - else: - assert len(values) == 0, f"{obj} expects 0 values, but got {values}" - return obj - - -class DSLCallable: - """ - Wrapper class for a callable object used within the DSL. - - DSLCallable is designed to wrap a function and provide additional - introspection utilities such as retrieving the argument specification - and signature. It ensures that the wrapped function can only be called - once, after which the reference to the function is cleared to prevent - further invocations. This is useful in scenarios where a function should - only be executed a single time within the DSL's execution model. - - Attributes: - func (callable): The function to be wrapped and managed. - - Methods: - __call__(*args, **kwargs): Calls the wrapped function and clears it. - """ - - def __init__(self, func): - self.func = func - - def __call__(self, *args, **kwargs): - ret = self.__func__(*args, **kwargs) - self.func = None - return ret - - @property - def __func__(self): - assert self.func is not None, "DSLCallable is already called" - return self.func - - @property - def __signature__(self): - return inspect.signature(self.__func__) - - @property - def __name__(self): - return self.__func__.__name__ - - -class BaseDSL: - gpu_module = None - - def __init__( - self, - *, - name: str, - dsl_package_name: List[str], - compiler_provider: Any, - pass_sm_arch_name: str, - device_compilation_only=False, - preprocess=False, - ): - """ - Constructor for initializing the class with required providers and environment settings. - - Parameters: - - name (str): Name of DSL, used for environment variables and logging. - - package_name (str): Name of the package, used for the preprocessor. - - compiler_provider (MLIR dialect): Provider for compiler. - - pass_sm_arch_name (str): The keyword name of the SM. - - device_compilation_only (bool) : Only device code, and call it via cuda driver - - preprocess (bool): Enable AST transformation. - - This constructs a DSL instance and sets up environment management, - warning configurations, and logging functionalities. It reads - environment variables using `EnvironmentVarManager` and configures - a logger with settings from the environment. If environment warnings - are detected, they are escalated to errors to ensure strict handling. - """ - # Enforcing initialization of instance variables - if not all([name, compiler_provider, pass_sm_arch_name]): - raise DSLRuntimeError( - "All required parameters must be provided and non-empty" - ) - - self.name = name - self.compiler_provider = compiler_provider - self.pass_sm_arch_name = pass_sm_arch_name - self.frame = None - self.no_cache = False - self.device_compilation_only = device_compilation_only - self.num_kernels = 0 - # Read environment variables - self.envar = EnvironmentVarManager(self.name) - self.enable_preprocessor = preprocess - # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default - self.jit_cache = ( - dict() - if self.envar.disable_file_caching - else load_cache_from_path(self.name, self.envar.file_caching_capacity) - ) - self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}" - self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}" - - # set warning - if not self.envar.enable_optimization_warnings: - # By default, optimization warnings are disabled - warnings.filterwarnings("ignore", category=DSLOptimizationWarning) - if self.envar.warnings_as_errors: - warnings.filterwarnings("error") - if self.envar.warnings_ignore: - warnings.filterwarnings("ignore") - - # Initialize logger - if self.envar.log_to_console == False and self.envar.jitTimeProfiling: - self.envar.log_to_console = True - self.envar.log_level = 20 # info level - setup_log( - self.name, - self.envar.log_to_console, - self.envar.log_to_file, - f"{self.name}.log", - self.envar.log_level, - ) - - # kernel symbols are temporary symbol string variables, their values are valid until the compilation is done. - self.kernel_symbols = [] - # used to generate unique name for gpu.launch - self.launch_inner_count = 0 - # initialize default compile options - self.compile_options = CompileOptions() - - if preprocess: - self.preprocessor = DSLPreprocessor(dsl_package_name) - log().info(f"Initializing {name} DSL") - log().debug(f"Logger initialized for {self.name}") - - # Hook excepthook - if self.envar.filterStacktrace: - origin_excepthook = sys.excepthook - module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__))) - - def excepthook(excep_type, value, traceback): - filter_exception(value, module_dir) - if hasattr(value, "__traceback__"): - origin_excepthook(excep_type, value, value.__traceback__) - else: - origin_excepthook( - excep_type, value, filter_stackframe(traceback, module_dir) - ) - - sys.excepthook = excepthook - - # Restore original excepthook - def restore_excepthook(hook): - sys.excepthook = hook - - atexit.register(restore_excepthook, origin_excepthook) - - def dump_cache(self): - if not self.envar.disable_file_caching: - dump_cache_to_path( - self.name, self.jit_cache, self.envar.file_caching_capacity - ) - - @lru_cache(maxsize=1) - def print_warning_once(self, message): - log().warning(f"Warning: {message}") - warnings.warn(message, UserWarning) - - def print_warning(self, message): - log().warning(f"Warning: {message}") - warnings.warn(message, UserWarning) - - @classmethod - @lru_cache(maxsize=1) - def _get_dsl(cls): - # Instantiate the DSL Class once - main_dsl = cls() - if not main_dsl.no_cache: - # register atexit callback - atexit.register(main_dsl.dump_cache) - return main_dsl - - @staticmethod - def _can_preprocess(**dkwargs): - """ - Check if AST transformation is enabled or not for `jit` and `kernel` decorators. - """ - return dkwargs.pop("preprocess", True) - - @staticmethod - def _get_original_function(fcn_ptr, name): - """ - Get the original function from the decorated function - """ - while fcn_ptr.__name__ != name: - # If the function is wrapped with functools, get from __wrapped__ - if hasattr(fcn_ptr, "__wrapped__"): - fcn_ptr = fcn_ptr.__wrapped__ - # If the function is wrapped manually, it's the first in clousure - elif callable(fcn_ptr.__closure__[0].cell_contents): - fcn_ptr = fcn_ptr.__closure__[0].cell_contents - else: - raise DSLRuntimeError( - f"Cannot find the original function {name} in the closure chain" - ) - return fcn_ptr - - @staticmethod - def _preprocess_and_execute(func): - """ - Run ast transformation and return the materialized function pointer - """ - if hasattr(func, "_transformed_ast"): - # If the function ptr is already materialized, use the existing one - func._dsl_object.frame = func._decorator_frame - if func._transformed_ast is None: - func._transformed_ast = func._dsl_object.run_preprocessor(func) - if func._transformed_ast is None: - del func._transformed_ast - func._dsl_object.frame = None - return func - - fcn_ptr = func._dsl_object.get_function_ptr(func) - # If the function is decorated, de-decorate it - fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__) - func._dsl_object.frame = None - return DSLCallable(fcn_ptr) - return func - - def jit_runner(self, executor, frame, *dargs, **dkwargs): - """ - Decorator to mark a function for JIT compilation. - """ - log().info("jit_runner") - - def jit_runner_decorator(func): - func._dsl_object = self - # Run preprocessor that alters AST - if self.enable_preprocessor and BaseDSL._can_preprocess(**dkwargs): - # For an annotated function, add some DSL attributes - # When materializing the AST, we need decorator's frame - func._decorator_frame = frame - # No transformed ast at this point - func._transformed_ast = None - - @wraps(func) - def jit_wrapper(*args, **kwargs): - func_ptr = BaseDSL._preprocess_and_execute(func) - return executor(func_ptr, *args, **kwargs) - - return jit_wrapper - - if len(dargs) == 1 and callable(dargs[0]): - return jit_runner_decorator(dargs[0]) - else: - return jit_runner_decorator - - @classmethod - def jit(cls, *dargs, **dkwargs): - """ - Decorator to mark a function for JIT compilation for Host code. - """ - frame = inspect.currentframe().f_back - # Instantiate the DSL Class - main_dsl = cls._get_dsl() - return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs) - - @classmethod - def kernel(cls, *dargs, **dkwargs): - """ - Decorator to mark a function for JIT compilation for GPU. - """ - frame = inspect.currentframe().f_back - # Instantiate the DSL Class - main_dsl = cls._get_dsl() - return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs) - - @abstractmethod - def _kernel_helper(self, func, *args, **kwargs): - """ - Helper function to handle kernel generation logic - """ - pass - - @abstractmethod - def _build_gpu_module(self, attrs): - """ - Build the module op that contains the kernels. - """ - pass - - @abstractmethod - def _get_pipeline(self, pipeline): - """ - Get the pipeline from the other configuration options. - """ - if pipeline != None: - return pipeline - return None - - @staticmethod - def log_additions(func_type, operands=None, types=None, arg_attrs=None): - if operands is not None and operands != []: - log().debug( - f"Added {func_type} operands: [%s]", ", ".join(map(str, operands)) - ) - if types is not None: - log().debug( - f"Added {func_type} arg_types: [%s]", ", ".join(map(str, types)) - ) - if arg_attrs is not None: - log().debug( - f"Added {func_type} arg_attrs: [%s]", ", ".join(map(str, arg_attrs)) - ) - - def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): - """Does simple name mangling""" - - for spec_arg, arg in zip(args_spec.args, args): - spec_ty = args_spec.annotations.get(spec_arg, None) - if spec_ty != None: - if issubclass(type(spec_ty), (t.IRValue, t.IRVariadic)): - continue - if isinstance(spec_ty, (ir.Type, ir.Value)): - continue - if isinstance(arg, (ir.Type, ir.Value, ir.OpResult)): - continue - if isinstance(type(arg), (ir.Type, ir.Value, ir.OpResult)): - continue - if self._is_tensor_descriptor(arg): - continue - if inspect.isclass(spec_ty): - class_name = str(arg).replace("class", "") - class_name = class_name.replace(" ", "") - function_name = f"{function_name}_{class_name}" - elif isinstance(arg, (list, tuple)): - function_name = f"{function_name}_{'_'.join(map(str, arg))}" - else: - function_name = f"{function_name}_{arg}" - # we would need a dedicated MR to follow up - unwanted_chars = r"'-![]#,.<>()\":{}=%?@;" - translation_table = str.maketrans("", "", unwanted_chars) - function_name = function_name.translate(translation_table) - # identify address and drop - function_name = re.sub(r"0x[a-f0-9]{8,16}", "", function_name) - function_name = re.sub(r"\s+", " ", function_name) - function_name = function_name.replace(" ", "_") - function_name = function_name.replace("\n", "_") - # max fname is 256 character, leave space - function_name = function_name[:180] - log().info(f"Final mangled function name: {function_name}") - return function_name - - def _generate_execution_arguments_for_known_types( - self, arg, arg_spec, arg_name, i, fop_args, iv_block_args - ): - """ - Generate MLIR arguments for known types. - - Sub-DSLs can override this method to handle types that are not - natively supported by the Base DSL. - """ - ir_arg = [] - if is_argument_constexpr(arg, arg_spec, arg_name, i, func): - ir_arg.append(arg) - - return ir_arg, iv_block_args - - def generate_execution_arguments( - self, - args, - kwargs, - fop, - args_spec: inspect.FullArgSpec, - ): - """Create list of arguments that will be passed to MLIR's func.func op""" - - def gen_exec_args(input_args, arg_names, annotations, fop_args): - assert len(input_args) == len(arg_names) - - ir_args = [] - iv_block_args = 0 - for i, arg in enumerate(input_args): - arg_name = arg_names[i] - arg_spec = annotations.get(arg_name, None) - log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec) - - # Implicit cast to NumericMeta - if isinstance(arg_spec, t.NumericMeta) and not isinstance( - arg, arg_spec - ): - arg = t.cast(arg, arg_spec) - - ir_arg, iv_block_args = ( - self._generate_execution_arguments_for_known_types( - arg, arg_spec, arg_name, i, fop_args, iv_block_args - ) - ) - - if not ir_arg: - # If it's not a known type, try JIT argument adapter - # to convert the argument if possible - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - arg = adapter(arg) if adapter else arg - - n_args = len(get_mlir_types(arg)) - blk_args = fop_args[iv_block_args : iv_block_args + n_args] - ir_arg.append(new_from_mlir_values(arg, blk_args)) - iv_block_args += n_args - - self.log_additions(ir_arg) - ir_args.extend(ir_arg) - - return ir_args, iv_block_args - - fop_args = list(fop.regions[0].blocks[0].arguments) - ir_args, iv_block_args = gen_exec_args( - args, args_spec.args, args_spec.annotations, fop_args - ) - ir_kwargs, _ = gen_exec_args( - [kwargs[arg] for arg in args_spec.kwonlyargs], - args_spec.kwonlyargs, - args_spec.annotations, - fop_args[iv_block_args:], - ) - ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)} - - log().debug("execution args: %s", ", ".join(map(str, ir_args))) - log().debug("execution kwargs: %s", ", ".join(map(str, ir_kwargs))) - return ir_args, ir_kwargs - - @abstractmethod - def _generate_mlir_type_for_tensor_descriptor(self, tensor): - """ - Generate MLIR type for the tensor descriptor. - """ - pass - - @abstractmethod - def _generate_executable_arg_for_tensor_descriptor( - self, mlir_value=None, ptr_tensor_ty=None, tensor=None - ): - """ - Generates executable value for the given tensor descriptor. - """ - pass - - def _get_globals(self): - """ - Combines global and local variables from the current context and the - caller's frame comes. This includes the current module's globals, the - global variables from the caller's frame, and the local variables from - the caller's frame. - - "self.frame" is used to fetch the caller's frame. - - AST preprocessor generates a new python code, so the resulting globals - dictionary is used to execute the python code. - """ - all_globals = {} - if self.frame: - all_globals.update(self.frame.f_globals) - all_globals.update(self.frame.f_locals) - return all_globals - - @abstractmethod - def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: - pass - - @abstractmethod - def _handle_tensor_descriptor( - self, maybe_tensor, arg_name: str, need_gpu_memory: bool - ) -> Any: - pass - - def _validate_arg(self, arg, arg_index, arg_name, arg_spec): - """ - Validates if the arg is really of the annotated type for type safety. - - The default implementation is empty. Subclasses can override this method to add more validation logic. - Returns None if validation passes, otherwise returns an error derived from DSLBaseError. - """ - pass - - def _generate_jit_func_args_for_known_types( - self, - func, - arg, - arg_name, - arg_spec, - arg_index, - *, - is_host=True, - ): - """ - Generate JIT function arguments for known types. - - Sub-DSLs can override this method to handle types that are not - natively supported by the Base DSL. - """ - - jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] - default_attr = ir.DictAttr.get({}) - - if is_argument_constexpr(arg, arg_spec, arg_name, arg_index, func): - jit_exec_arg = jit_arg_type = jit_arg_attr = None - - return jit_exec_arg, jit_arg_type, jit_arg_attr - - def _generate_jit_func_args( - self, - func, - function_name, - args, - kwargs, - args_spec: inspect.FullArgSpec, - *, - is_host=True, - ): - """Generate JIT function arguments.""" - - assert len(args) == len(args_spec.args) and len(kwargs) == len( - args_spec.kwonlyargs - ), ( - f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args " - f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}" - ) - - jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], [] - jit_adapted_args = [] - default_attr = ir.DictAttr.get({}) - - input_args = [*args, *kwargs.values()] - input_arg_names = [*args_spec.args, *args_spec.kwonlyargs] - for i, (arg_name, arg) in enumerate(zip(input_arg_names, input_args)): - spec_ty = args_spec.annotations.get(arg_name, None) - log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty) - - # Implicitly convert into Numeric type if possible - if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty): - arg = t.cast(arg, spec_ty) - - # Type safety check - if spec_ty is not None: - err = self._validate_arg(arg, i, arg_name, spec_ty) - if err is not None: - raise err - - jit_exec_arg, jit_arg_type, jit_arg_attr = ( - self._generate_jit_func_args_for_known_types( - func, - arg, - arg_name, - spec_ty, - i, - is_host=is_host, - ) - ) - - if jit_arg_type is not None and len(jit_arg_type) == 0: - # If not any known type, try JIT argument adapter - # to convert the argument - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - if adapter: - arg = adapter(arg) - jit_adapted_args.append(arg) - - if is_host: - jit_exec_arg.extend(get_c_pointers(arg)) - jit_arg_type.extend(get_mlir_types(arg)) - else: - dyn_vals = extract_mlir_values(arg) - jit_exec_arg.extend(dyn_vals) - jit_arg_type.extend([v.type for v in dyn_vals]) - - if not jit_arg_type or not jit_exec_arg: - if (is_host and hasattr(arg, "__c_pointers__")) or ( - not is_host - and hasattr(arg, "__extract_mlir_values__") - and hasattr(arg, "__new_from_mlir_values__") - ): - pass - else: - raise DSLRuntimeError( - f"failed to generate argument #{i+1} ({arg_name}) for JIT function '{function_name}'.", - context={ - f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.", - f"Call-site argument value": arg, - f"Call-site argument type": type(arg), - }, - suggestion=f"Consider annotating the argument with `{arg_name} : Constexpr` " - "if it's a value known at compile-time. " - f"Otherwise, implement the {'`JitArgument`' if is_host else '`DynamicExpression`'} " - f"protocol or register a custom JIT argument adapter for type `{type(arg)}` to " - "enable dynamic value conversion at runtime.", - ) - - jit_arg_attr.extend([default_attr] * len(jit_arg_type)) - - if jit_arg_type is not None: - jit_exec_args.extend(jit_exec_arg) - jit_arg_types.extend(jit_arg_type) - jit_arg_attrs.extend(jit_arg_attr) - - return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args - - def generate_mlir_function_types( - self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec - ): - """Convert input arguments to MLIR function signature also convert numpy arrays to memref.""" - - exe_args, types, attrs, adapted_args = self._generate_jit_func_args( - func, function_name, input_args, kwargs, args_spec, is_host=True - ) - - log().debug("Execution Arguments: %s", ", ".join(map(str, exe_args))) - log().debug("Types: %s", ", ".join(map(str, types))) - - assert len(exe_args) == len( - types - ), "expects the same number of arguments and function parameters" - - return exe_args, types, adapted_args - - @dataclass - class LaunchConfig: - cluster: list = None - grid: list = field(default_factory=lambda: [1, 1, 1]) - block: list = field(default_factory=lambda: [1, 1, 1]) - smem: int = None - async_deps: list = field(default_factory=list) - has_cluster: bool = False - min_blocks_per_mp: int = 0 - auto_smem: bool = False - - def __post_init__(self): - if len(self.grid) != 3: - raise DSLRuntimeError(f"Expect 3d grid!") - if len(self.block) != 3: - raise DSLRuntimeError(f"Expect 3d block!") - - if self.smem is None: - self.smem = 0 - self.auto_smem = True - - self.has_cluster = self.cluster is not None - if self.cluster is None: - self.cluster = [None, None, None] - elif len(self.cluster) != 3: - raise DSLRuntimeError(f"Expect 3d cluster!") - - def diagnostic(self): - """Check command line parameters and enables diagnostic""" - # Check command line arguments "-diagnostic" - parser = argparse.ArgumentParser(description="Process diagnostic status.") - parser.add_argument( - "-diagnostic", - nargs="?", - const="all", - choices=["all", "fail", "success", "info", "suggestion"], - help="Set diagnostic status (fail, success, info, suggestion).", - ) - - args, _ = parser.parse_known_args() - ctx = ir.Context.current - - def callback(d): - print(f" [{self.name} Diagnostic] : {d.message}") - - ctx.attach_diagnostic_handler(callback) - - # Early return, don't enable diagnostics - if args.diagnostic is None: - return - - # Enable MLIR Flags - ctx.emit_error_diagnostics = True - ir._GlobalDebug.flag = True - if args.diagnostic == "all": - ir._GlobalDebug.set_types("diagnostic") - else: - ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}") - - def get_location(self): - """ - Get python location information and generate MLIR location - """ - - if self.frame is None: - log().debug("Frame is None") - return None - - file_loc = ir.Location.file( - self.frame.f_code.co_filename, self.frame.f_lineno, 0 - ) - - loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc) - return loc - - def compile_and_jit(self, module, pipeline, shared_libs, function_name=""): - """ - Compile and JIT an MLIR module. - """ - - try: - self.diagnostic() - - orig_stdout = sys.stdout - orig_stderr = sys.stderr - sys.stderr = redirect_stderr = io.StringIO() - sys.stdout = redirect_stdout = io.StringIO() - - try: - kernel = self.compiler_provider.compile_and_jit( - module, - pipeline, - shared_libs=shared_libs, - cuda_toolkit=self.envar.cuda_toolkit, - arch=self.envar.arch, - ) - - finally: - sys.stdout = orig_stdout - sys.stderr = orig_stderr - ir._GlobalDebug.flag = False - - # Print captured output. - print(redirect_stdout.getvalue(), file=sys.stdout, end="") - print(redirect_stderr.getvalue(), file=sys.stderr, end="") - - return kernel - - except Exception as e: - raise DSLRuntimeError("🧊🧊🧊 ICE 🧊🧊🧊", cause=e) - finally: - pass - - def preprocess_pipeline(self, pipeline, arch) -> str: - - if self.envar.cuda_toolkit is None: - self.print_warning( - "CUDA_TOOLKIT_PATH environment variable is not set. Cannot set toolkitPath." - ) - - options = { - "toolkitPath": self.envar.cuda_toolkit if self.envar.cuda_toolkit else None, - self.pass_sm_arch_name: arch, - } - - opt_str = "" - for k, v in options.items(): - if v: - opt_str += f"{k}={v} " - - if opt_str: - # Automatically append the pipeline options if any is specified through env var - pattern = re.compile(r"{(.+)}") - match = pattern.search(pipeline) - if match: - opt_str = f"{{{match[1]} {opt_str}}}" - pipeline = re.sub(r"{.+}", opt_str, pipeline) - else: - pipeline = pipeline.rstrip(")") + f"{{{opt_str}}})" - log().debug(f"Using pipeline = {pipeline}") - return pipeline - - def get_shared_libs(self) -> list: - shared_libs = [] - support_libs = self.envar.shared_libs - if support_libs is not None: - _libs = support_libs.split(":") - for lib in _libs: - if not os.path.exists(lib): - raise FileNotFoundError( - errno.ENOENT, os.strerror(errno.ENOENT), lib - ) - shared_libs.append(lib) - else: - self.print_warning(f"{self.name}_LIBS environment variable is not set") - - return shared_libs - - @lru_cache(maxsize=1) - def get_version(self): - version_hash = hashlib.sha256() - - return version_hash - - def get_module_hash(self, module, function_name): - s = io.BytesIO() - module.operation.write_bytecode(s) - for attr, value in self.envar.__dict__.items(): - if value is not None: - s.write(str(value).encode()) - # Add compile options to the hash - s.write(self.compile_options.to_str().encode()) - module_hash = self.get_version().copy() - module_hash.update(s.getvalue()) - module_hash = module_hash.hexdigest() - - log().debug("Bytecode=[%s]", s.getvalue().hex()) - log().debug("Version=[%s]", self.get_version().hexdigest()) - log().info( - "Function=[%s] Computed module_hash=[%s]", function_name, module_hash - ) - return module_hash - - def build_module(self, module, function_name: str): - """ - Build the MLIR module, verify and return the module - """ - - # Save IR in a file - if self.envar.keepIR: - save_ir(self.name, module, function_name) - - if self.envar.printIR: - print("\n//===--- ------ Generated IR ------ ---====\n") - module.operation.print( - enable_debug_info=self.envar.generate_source_location - ) - print("\n//===--- --- End of Generated IR -- ---====\n") - - # Verify the module - try: - module.operation.verify() - except Exception as e: - raise DSLRuntimeError(f"🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊", cause=e) - - return module - - def generate_original_ir( - self, - ir, - func, - funcBody, - kwargs, - function_name, - func_types, - gpu_module_attrs, - args, - args_spec, - ): - # This location is set to None for now; otherwise, calls to the same - # function on different lines would produce different line numbers, - # which would break the cache. - loc = None # self.get_location() - - def build_ir_module(): - module = ir.Module.create(loc=loc) - unit_attr = ir.UnitAttr.get() - module.operation.attributes["gpu.container_module"] = unit_attr - - with ir.InsertionPoint(module.body): - # Always generate gpu module. It's canonicalized by the compiler when it's not used. - self._build_gpu_module(gpu_module_attrs) - - fop = func.FuncOp(function_name, (func_types, []), loc=loc) - fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - log().debug("Generated Function OP [%s]", fop) - with ir.InsertionPoint(fop.add_entry_block()): - ir_args, ir_kwargs = self.generate_execution_arguments( - args, kwargs, fop, args_spec - ) - # Call user function body - try: - result = funcBody(*ir_args, **ir_kwargs) - func.ReturnOp([]) - except NameError as name_error: - raise DSLRuntimeError( - f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥", - cause=name_error, - suggestion="Using variables defined in dynamic control flow is not supported. Please give an initial value before control flow.", - ) - except DSLRuntimeError as dsl_error: - # Throw it's already a DSL error - raise dsl_error - return module, result - - # Build IR module - profiler = timer(enable=self.envar.jitTimeProfiling) - module, result = profiler(build_ir_module)() - module_hash = self.get_module_hash(module, function_name) - - module = self.build_module(module, function_name) - - return module, module_hash, result - - def compile_and_cache( - self, module, module_hash, function_name, pipeline, args_spec, no_cache - ): - arch = self.envar.arch - pipeline = self.preprocess_pipeline(self._get_pipeline(pipeline), arch) - shared_libs = self.get_shared_libs() - profiler = timer(enable=self.envar.jitTimeProfiling) - if ( - no_cache - or module_hash not in self.jit_cache - or self.jit_cache[module_hash].ir_module is None - ): - log().info( - "JIT cache miss function=[%s] module_hash=[%s]", - function_name, - module_hash, - ) - # Compile and JIT MLIR module - engine = profiler(self.compile_and_jit)( - module, pipeline, shared_libs, function_name=function_name - ) - else: - log().info( - "JIT cache hit IN-FILE function=[%s] module_hash=[%s]", - function_name, - module_hash, - ) - module = self.jit_cache[module_hash].ir_module - engine = self.compiler_provider.jit(module, shared_libs=shared_libs) - capi_func = profiler(engine.lookup)(function_name) - jit_executor = JitExecutor( - self, - engine, - capi_func, - module, - args_spec, - function_name, - jit_time_profiling=self.envar.jitTimeProfiling, - ) - jit_executor = jit_executor.update_jit_cuda_modules(self.kernel_symbols) - - if not no_cache: - # module stored in cache is compiled. - self.jit_cache[module_hash] = jit_executor - - return jit_executor - - def post_compilation_cleanup(self): - """Clean up some internal state after one compilation is completed.""" - # clear the kernel symbols after the compilation is done. - self.kernel_symbols = [] - self.launch_inner_count = 0 - # reset num_kernels to 0 for next compilation. - self.num_kernels = 0 - # reset the compile options after the compilation is done. - self.compile_options = CompileOptions() - - def generate_mlir( - self, - funcBody, - kwargs, - function_name, - gpu_module_attrs, - args, - args_spec, - pipeline, - no_cache, - compile_only, - loc=None, - ): - """Generate MLIR module and compile iself.T_provider.""" - with ir.Context(), ir.Location.unknown(): - # Convert input arguments to MLIR arguments - exe_args, func_types, adapted_args = self.generate_mlir_function_types( - funcBody, function_name, args, kwargs, args_spec - ) - - # Generate original ir module and its hash value. - module, module_hash, result = self.generate_original_ir( - ir, - func, - funcBody, - kwargs, - function_name, - func_types, - gpu_module_attrs, - args, - args_spec, - ) - - # dryrun is used to only generate IR - if self.envar.dryrun: - return result - - if ( - no_cache - or module_hash not in self.jit_cache - or self.jit_cache[module_hash].capi_func is None - ): - # no cache or cache miss, do ir generation/compilation/jit engine - jit_executor = self.compile_and_cache( - module, module_hash, function_name, pipeline, args_spec, no_cache - ) - else: - # cache hit - log().info( - "JIT cache hit IN-MEMORY function=[%s] module_hash=[%s]", - function_name, - module_hash, - ) - jit_executor = self.jit_cache[module_hash] - - self.post_compilation_cleanup() - # If compile_only is set, bypass execution return the jit_executor directly - if compile_only: - return jit_executor - # Run the compiled program - jit_executor.run_compiled_program(exe_args) - - return result - - def run_preprocessor(self, funcBody): - if not hasattr(funcBody, "_preprocessed"): - function_name = funcBody.__name__ - self.funcBody = funcBody - log().info("Started preprocessing [%s]", function_name) - exec_globals = self._get_globals() - transformed_ast = self.preprocessor.transform(funcBody, exec_globals) - if self.envar.print_after_preprocessor: - log().info( - f"# Printing unparsed AST after preprocess of func=`{function_name}` id=`{id(funcBody)}`" - ) - DSLPreprocessor.print_ast(transformed_ast) - funcBody._preprocessed = True - return transformed_ast - return None - - def get_function_ptr(self, original_function): - file_name = inspect.getsourcefile(original_function) - code_object = compile( - original_function._transformed_ast, filename=file_name, mode="exec" - ) - return self.preprocessor.exec( - original_function.__name__, - original_function, - code_object, - self._get_globals(), - ) - - def _get_function_bound_args(self, sig, func_name, *args, **kwargs): - """ - Binds provided arguments to a function's signature and applies default values. - - E.g. given a function signature `def foo(a, b=2, c=3)`, and at call-site if we do - `foo(a=1, c=4)`, the returned BoundArguments object will have args = `[1]` - and kwargs = `{'b': 2, 'c': 4}` - - An exception will be raised if binding fails. - """ - try: - bound_args = sig.bind_partial(*args, **kwargs) - bound_args.apply_defaults() - except Exception as e: - raise DSLRuntimeError( - f"Failed to bind arguments to function `{func_name}` with signature `{sig}`", - cause=e, - ) - return bound_args - - def _canonicalize_args(self, sig, *args, **kwargs): - """ - Canonicalize the input arguments so that returned args only contain - positional arguments and kwargs only contain keyword arguments. - """ - function_name = self.funcBody.__name__ - bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) - canonicalized_args = bound_args.args - canonicalized_kwargs = bound_args.kwargs - return canonicalized_args, canonicalized_kwargs - - def _check_arg_count(self, *args, **kwargs): - if not self.funcBody: - raise DSLRuntimeError("Function body is not set.") - - # Pass the actual function object to inspect.signature to get the signature. - sig = inspect.signature(self.funcBody) - - function_name = self.funcBody.__name__ - - bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) - - # Check if all non-default arguments are provided - for param in sig.parameters.values(): - if ( - param.default is inspect.Parameter.empty - and param.name not in bound_args.arguments - ): - raise DSLRuntimeError( - f"Missing required argument in `{function_name}`: '{param.name}'" - ) - - return sig - - def _func(self, funcBody, *args, **kwargs): - """Decorator for MLIR functions. - It cuts the boilerplate code, does the following: - 1. Generates `func.func` - 2. Types translation (numpy arrays -> cute.memref, float -> , etc.) - 3. Compiles and JITs the MLIR module - 4. Invokes the generated function - 5. Operator overloading (a + b --> arith.addi a, b) - 6. Generates GPU kernel function with GPU module and kernel attributes baked - """ - if ir.Context.current is None: - pass - elif ir.InsertionPoint.current is not None: - return funcBody(*args, **kwargs) - - function_name = funcBody.__name__ - self.funcBody = funcBody - - pipeline = kwargs.pop("pipeline", None) - gpu_module_attrs = kwargs.pop("gpu_module_attrs", {}) - - # Disable cache - no_cache = kwargs.pop("no_cache", False) - - # Always compile(disable cache) and return the result jit_executor - compile_only = kwargs.pop("compile_only", False) - - if not no_cache and compile_only: - no_cache = True - self.print_warning("Cache is disabled as user wants to compile only.") - - # Check the number of arguments - sig = self._check_arg_count(*args, **kwargs) - - args_spec = inspect.getfullargspec(funcBody) - - # Canonicalize the input arguments - canonicalized_args, canonicalized_kwargs = self._canonicalize_args( - sig, *args, **kwargs - ) - - # Simple name mangling - function_name = self.mangle_name(function_name, canonicalized_args, args_spec) - - # Generate MLIR Context and start generating IR - log().debug(f"Generating MLIR for function '{function_name}'") - result = self.generate_mlir( - funcBody, - canonicalized_kwargs, - function_name, - gpu_module_attrs, - canonicalized_args, - args_spec, - pipeline, - no_cache, - compile_only, - ) - - return result - - class _KernelGenHelper(ABC): - def __init__(self): - self.func_op = None - self.func_type = None - - @abstractmethod - def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): - assert arg_types is not None, "Invalid arg_types!" - assert kernel_name is not None, "kernel name is empty" - pass - - @abstractmethod - def generate_func_ret_op(self): - pass - - @abstractmethod - def generate_launch_op(self, *args, **kwargs): - pass - - @abstractmethod - def get_func_body_start(self): - pass - - @abstractmethod - def enter_gpu_module(module): - """Compute the insertion point into the given module.""" - pass - - @lru_cache(maxsize=1) - def _get_default_stream(self): - """Returns the default stream 0""" - from .runtime import cuda as cuda_helpers - - return cuda_helpers.stream_create() - - def _execute_cuda( - self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None - ): - """ - Executes a specified CUDA kernel from a cubin file, handling module loading, - kernel retrieval, stream creation, kernel launch, and synchronization. - """ - from .runtime import cuda as cuda_helpers - - # Step 1. Load CUDA Module - module = cuda_helpers.load_cubin_module(fname_cubin) - # Step 2. Find CUDA function - kernel_ptr = cuda_helpers.get_kernel_function(module, kernel_name) - - sync_execution_default = False - if stream is None: - stream = self._get_default_stream() - sync_execution_default = True - - # Step 4. Launch the kernel - cuda_helpers.launch_kernel( - kernel_ptr, - grid_size, - block_size, - stream, - smem_size=smem_size, - kernel_args=self.exe_args, - ) - - if sync_execution_default: - # Step 5. Optional Sync cuda stream - cuda_helpers.stream_sync(stream) - - def _execute_by_cuda_driver( - self, - kernel_generator, - generate_cubin, - grid_size, - block_size, - smem_size, - stream=None, - ): - """ - This function builds IR and execute the module using cuda driver. - It doesn't use mlir's cuda runtime - """ - ret = None - - # Step 1. Build IR - with ir.Context(), ir.Location.unknown(): - loc = self.get_location() - module = ir.Module.create(loc=loc) - unit_attr = ir.UnitAttr.get() - module.operation.attributes["gpu.container_module"] = unit_attr - with ir.InsertionPoint(module.body): - self._build_gpu_module() - ret, kernel_name = kernel_generator() - log().debug( - f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}" - ) - - module = self.build_module(module, kernel_name) - - # dryrun is used to only generate IR - if self.envar.dryrun: - return ret - - # Generate cubin - fname_cubin = generate_cubin(module, kernel_name) - - # Execute a cuda kernel from cubin - self._execute_cuda( - fname_cubin, kernel_name, grid_size, block_size, smem_size, stream - ) - - return ret - - def generate_kernel_operands_and_types( - self, kernel_func, kernel_name, args_spec, args, kwargs - ): - """ - Generate the operands and types for the kernel function - """ - - kernel_operands, kernel_arg_types, kernel_arg_attrs = [], [], [] - - log().debug( - "Processing GPU kernel call in [%s] mode", - ( - f"Only {self.device_jit_decorator_name}" - if self.device_compilation_only - else f"{self.host_jit_decorator_name} + {self.device_jit_decorator_name}" - ), - ) - - if self.device_compilation_only: - return kernel_operands, kernel_arg_types, kernel_arg_attrs - - kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = ( - self._generate_jit_func_args( - kernel_func, kernel_name, args, kwargs, args_spec, is_host=False - ) - ) - - log().debug("Final kernel_operands: %s", ", ".join(map(str, kernel_operands))) - log().debug("Final kernel_arg_types: %s", ", ".join(map(str, kernel_arg_types))) - log().debug("Final kernel_arg_attrs: %s", ", ".join(map(str, kernel_arg_attrs))) - - assert ( - len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs) - ), "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal" - - return kernel_operands, kernel_arg_types, kernel_arg_attrs - - def kernel_launcher(self, *dargs, **dkwargs): - def decorator(funcBody): - @wraps(funcBody) - def kernel_wrapper(*args, **kwargs): - """ - Base decorator for generating kernel function - - This decorator provides a template for kernel function generation - including kernel function header/body and kernel launch op at call site - - Optional arguments (with default value in <>): - - requiredArgs <[]>: specifies the mandatory arguments that must present in kernel function signature - the args will be validated and collected as a namedtuple - - optionalArgs <[]>: specifies the optional arguments that might present in kernel function signature - the args will be collected (if present) as a namedtuple - - unitAttrNames <[]>: specifies the name(s) of ir.UnitAttr to be set for kernel function op - - valueAttrDict <{}>: specifies the name(s) and value(s) of ir.Attribute to be set for kernel function op - - kernelGenHelper : specifies the mandatory customized kernel generation helper class (derived from _KernelGenHelper) - - Return value: - A namedtuple "KernelReturns" is returned with following fields: - - kernel_func_ret: the return of the kernel function - - launch_op_ret: the return of the launch op - """ - - requiredArgs = dkwargs.get("requiredArgs", []) - optionalArgs = dkwargs.get("optionalArgs", []) - unitAttrNames = dkwargs.get("unitAttrNames", []) - valueAttrDict = dkwargs.get("valueAttrDict", {}) - kernelGenHelper = dkwargs.get("kernelGenHelper", None) - - kernel_name = funcBody.__name__ - args_spec = inspect.getfullargspec(funcBody) - self.funcBody = funcBody - - # Give each kernel a unique name. (The same kernel may be - # called multiple times, resulting in multiple kernel traces.) - # The mangled name of Python function is part of the name to - # improve readability. - kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}" - self.num_kernels += 1 - - # Step 0. Preprocess the arguments - def extract_args(argNames, assertIfNone=False) -> list: - extracted = [] - for name in argNames: - value = kwargs.pop(name, None) - if assertIfNone and value is None: - raise DSLRuntimeError( - f"{name} is required for {kernel_name}" - ) - extracted.append(value) - - return extracted - - RequiredArgs = namedtuple("RequiredArgs", requiredArgs) - req_args = ( - RequiredArgs._make(extract_args(requiredArgs, assertIfNone=True)) - if requiredArgs - else None - ) - OptionalArgs = namedtuple("OptionalArgs", optionalArgs) - opt_args = ( - OptionalArgs._make(extract_args(optionalArgs)) - if optionalArgs - else None - ) - assert ( - kernelGenHelper is not None - ), "kernelGenHelper should be explicitly specified!" - - # check arguments - sig = self._check_arg_count(*args, **kwargs) - - # Canonicalize the input arguments - canonicalized_args, canonicalized_kwargs = self._canonicalize_args( - sig, *args, **kwargs - ) - - kernel_operands, kernel_types, kernel_arg_attrs = ( - self.generate_kernel_operands_and_types( - funcBody, - kernel_name, - args_spec, - canonicalized_args, - canonicalized_kwargs, - ) - ) - - with self._enter_gpu_module(): - log().debug("Generating device kernel") - if self.device_compilation_only: - log().debug("Generating cuda-python arguments") - # Convert input arguments to MLIR arguments - self.exe_args, kernel_types, _ = ( - self.generate_mlir_function_types( - funcBody, - kernel_name, - canonicalized_args, - canonicalized_kwargs, - args_spec, - ) - ) - - helper = kernelGenHelper() - loc = self.get_location() - fop = helper.generate_func_op( - kernel_types, kernel_arg_attrs, kernel_name, loc - ) - log().debug(f"Kernel function op: {fop}") - for attr in unitAttrNames: - fop.attributes[attr] = ir.UnitAttr.get() - for key, val in valueAttrDict.items(): - fop.attributes[key] = val - - fop.sym_visibility = ir.StringAttr.get("public") - with ir.InsertionPoint(helper.get_func_body_start()): - ir_args, ir_kwargs = self.generate_execution_arguments( - canonicalized_args, canonicalized_kwargs, fop, args_spec - ) - log().debug( - f"IR arguments - args: {ir_args} ; kwargs: {ir_kwargs}" - ) - # Call user function body - kernel_ret = funcBody(*ir_args, **ir_kwargs) - helper.generate_func_ret_op() - - # Step 3. Generate call site `launch_func` - kernel_sym = ir.SymbolRefAttr.get(["kernels", kernel_name]) - launch_ret = helper.generate_launch_op( - kernelSym=kernel_sym, - kernelOperands=kernel_operands, - requiredArgs=req_args, - optionalArgs=opt_args, - ) - - KernelReturns = namedtuple( - "KernelReturns", ["kernel_func_ret", "launch_op_ret"] - ) - result = KernelReturns( - kernel_func_ret=kernel_ret, launch_op_ret=launch_ret - ) - log().debug(f"Kernel result: {result}, kernel name: {kernel_name}") - return result, kernel_name - - return kernel_wrapper - - if len(dargs) == 1 and callable(dargs[0]): - return decorator(dargs[0]) - else: - return decorator diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py deleted file mode 100644 index fa683477f3fb5b18f5459e19bdd468432590b952..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py +++ /dev/null @@ -1,320 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides utilities for the environment variables setup. - -It provides an EnvironmentVarManager, which reads environment variables for the DSL -and caches them for efficient access. - -It also provides utilities to automatically setup a subset of environment variables -based on heuristics. -""" - -import os -import sys -import shutil -import glob -from pathlib import Path -from functools import lru_cache -from typing import Any - -from ..base_dsl.runtime.cuda import get_compute_capability_major_minor -from .utils.logger import log - -IS_WINDOWS = sys.platform == "win32" -CLIB_EXT = ".dll" if IS_WINDOWS else ".so" - -# ============================================================================= -# Environment Variable Helpers -# ============================================================================= - - -@lru_cache(maxsize=None) -def get_str_env_var(var_name, default_value=None): - value = os.getenv(var_name) - return value if value is not None else default_value - - -@lru_cache(maxsize=None) -def get_bool_env_var(var_name, default_value=False): - value = get_str_env_var(var_name) - if value is None: - return default_value - return value not in {"False", "0", ""} - - -@lru_cache(maxsize=None) -def get_int_env_var(var_name, default_value=0): - value = get_str_env_var(var_name) - return int(value) if value and value.isdigit() else default_value - - -@lru_cache(maxsize=None) -def has_env_var(var_name): - return os.getenv(var_name) is not None - - -def detect_gpu_arch(prefix): - """ - Attempts to detect the machine's GPU architecture. - - Returns: - A string representing the GPU architecture (e.g. "70" for compute capability 7.0), - or a default value(e.g. "sm_100") if the GPU architecture cannot be determined. - """ - arch = (None, None) - try: - arch = get_compute_capability_major_minor() - except Exception as e: - log().info(f"Failed to get CUDA compute capability: {e}") - - if arch == (None, None): - # default to sm_100 - arch = (10, 0) - - major, minor = arch - suffix = "" - if major >= 9: - suffix = "a" - - return f"sm_{major}{minor}{suffix}" - - -def find_libs_in_ancestors(start, target_libs, lib_folder_guesses): - """ - Search ancestor directories for a candidate library folder containing all required libraries. - - Starting from the given path, this function traverses up through each parent directory. - For every ancestor, it checks candidate subdirectories (specified by lib_folder_guesses) - for files that match the required library extension (CLIB_EXT). Library file names are - canonicalized by removing the "lib" prefix from their stem. If a candidate directory contains - all of the required libraries (as specified in target_libs), the function returns a list of - absolute paths to these library files. - - Parameters: - start (str or Path): The starting directory from which to begin the search. - target_libs (iterable of str): A collection of required library names (without the "lib" prefix). - lib_folder_guesses (iterable of str): Relative paths from an ancestor directory that may contain the libraries. - - Returns: - list[str] or None: A list of resolved paths to the required library files if found; otherwise, None. - """ - # Traverse through all parent directories of the resolved starting path. - for ancestor in Path(start).resolve().parents: - # Iterate over each candidate relative directory path. - for rel_path in lib_folder_guesses: - target_dir = ancestor / rel_path - # Skip if the candidate directory does not exist. - if not target_dir.is_dir(): - continue - - # Initialize a list to hold the resolved paths of matching library files. - libs_cand = [] - # Create a set of the remaining libraries we need to find. - remaining_libs = set(target_libs) - - # Iterate over all items in the candidate directory. - for p in target_dir.iterdir(): - # Consider only files with the expected library extension. - if p.suffix == CLIB_EXT: - # Canonicalize the library name by removing the "lib" prefix. - lib_name = p.stem.removeprefix("lib") - # If this library is required, add its resolved path and mark it as found. - if lib_name in remaining_libs: - libs_cand.append(str(p.resolve())) - remaining_libs.remove(lib_name) - - # If all required libraries have been found, return the list of library paths. - if len(remaining_libs) == 0: - return libs_cand - - # Return None if no candidate directory contains all required libraries. - return None - - -def _find_cuda_home(): - """Find the CUDA installation path using a series of heuristic methods. - Methods below are checked in order, and the function returns on first match: - 1. Checking the environment variables CUDA_HOME and CUDA_PATH. - 2. Searching for the 'nvcc' compiler in the system PATH and deriving the path of cuda. - 3. Scanning common installation directories based on the operating system. - - On Windows systems (when IS_WINDOWS is True), it searches in: - C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.* - - On Unix-like systems, it searches in: - /usr/local/cuda* - - Returns: - Optional[str]: The absolute CUDA installation path if found; otherwise, None. - - Note: - The variable IS_WINDOWS is defined in the module scope. - """ - # Guess #1 - cuda_home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH") - if cuda_home is None: - # Guess #2 - nvcc_path = shutil.which("nvcc") - if nvcc_path is not None: - cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) - else: - # Guess #3 - if IS_WINDOWS: - glob_pat = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*" - else: - glob_pat = "/usr/local/cuda*" - cuda_homes = glob.glob(glob_pat) - if len(cuda_homes) == 0: - cuda_home = "" - else: - cuda_home = cuda_homes[0] - if not os.path.exists(cuda_home): - cuda_home = None - return cuda_home - - -def get_cuda_toolkit_path(): - """ - Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if - set. Otherwise, attempts to discover a valid CUDA toolkit location and - return. If not found, return None. - """ - # Check if the environment variable is already set, if so, return it immediately. - try: - cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH") - if cuda_toolkit_path_existing: - return cuda_toolkit_path_existing - - found_cuda_home = _find_cuda_home() - if found_cuda_home: - return found_cuda_home - except Exception as e: - log().info("default_env: exception on get_cuda_toolkit_path", e) - return None - - -def get_prefix_dsl_libs(prefix: str): - """ - Returns get_str_env_var('{prefix}_LIBS') if set. - Otherwise, attempts to discover libs based on heuristics and return - If not found, return None. - """ - # Check if the environment variable is already set, if so, return it immediately. - try: - prefix_libs_existing = get_str_env_var(f"{prefix}_LIBS") - if prefix_libs_existing: - return prefix_libs_existing - - def get_libs_cand(start): - target_libs = { - "mlir_c_runner_utils", - "mlir_runner_utils", - "mlir_cuda_runtime", - } - lib_folder_guesses = [ - "lib", - ] - - libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses) - if libs_cand: - dsl_libs = ":".join(libs_cand) - return dsl_libs - - return None - - # find from install folder - dsl_libs = get_libs_cand(__file__) - - if not dsl_libs: - # try to find from build folder structure - dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve()) - - return dsl_libs - - except Exception as e: - log().info(f"default_env: exception on get_prefix_dsl_libs", e) - return None - - -class EnvironmentVarManager: - """Manages environment variables for configuration options. - - Printing options: - - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False) - - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False) - - [DSL_NAME]_PRINT_IR: Print generated IR (default: False) - - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True) - File options: - - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) - - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False) - Other options: - - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1). - - [DSL_NAME]_DRYRUN: Generates IR only (default: False) - - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100") - - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False) - - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False) - - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) - - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) - - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) - - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) - - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) - - [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False) - """ - - def __init__(self, prefix="DSL"): - self.prefix = prefix # change if needed - - # Printing options - self.print_after_preprocessor = get_bool_env_var( - f"{prefix}_PRINT_AFTER_PREPROCESSOR", False - ) - self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False) - self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) - # File options - self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False) - # Logging options - self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False) - self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False) - if ( - has_env_var(f"{prefix}_LOG_LEVEL") - and not self.log_to_console - and not self.log_to_file - ): - log().warning( - f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!" - ) - self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1) - - # Other options - self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) - self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) - self.warnings_as_errors = get_bool_env_var( - f"{prefix}_WARNINGS_AS_ERRORS", False - ) - self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False) - self.enable_optimization_warnings = get_bool_env_var( - f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False - ) - self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False) - self.disable_file_caching = get_bool_env_var( - f"{prefix}_DISABLE_FILE_CACHING", False - ) - self.file_caching_capacity = get_int_env_var( - f"{prefix}_FILE_CACHING_CAPACITY", 1000 - ) - self.generate_source_location = not get_bool_env_var( - f"{prefix}_NO_SOURCE_LOCATION", False - ) - # set cuda - self.cuda_toolkit = get_cuda_toolkit_path() - - # set mlir shared libraries - self.shared_libs = get_prefix_dsl_libs(prefix) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py deleted file mode 100644 index 83268009c85ef64967d6a81ab886ebeb704f140d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py +++ /dev/null @@ -1,357 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides jit executor related classes -""" -import ctypes -import inspect -import io -from typing import get_origin - -import numpy as np - -# MLIR modules imports -from .._mlir import ir - -# Local modules imports -from . import typing as t -from .common import DSLRuntimeError -from .runtime import cuda as cuda_helpers -from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr -from .typing import get_c_pointers -from .utils.logger import log -from .utils.timer import timer - - -class CudaSingleModule: - def __init__(self, cuda_module, kernel_ptr): - self.cuda_module = cuda_module - self.kernel_ptr = kernel_ptr - - -class CudaModules: - def __init__(self, modules, args): - # list of CudaSingleModule - self.modules = modules - # extra kernel ptr arguments for launch - self.args = args - - -class JitExecutor: - def __init__( - self, - dsl, - engine, - capi_func, - ir_module, - args_spec, - function_name, - cuda_modules: CudaModules = None, - jit_time_profiling=False, - ): - self.dsl = dsl - self.engine = engine - self.capi_func = capi_func - self.ir_module = ir_module - self.args_spec = args_spec - self.function_name = function_name - if args_spec is not None: - self.original_args_spec = args_spec - self.args_spec = self.filter_runtime_arg_spec(args_spec) - # cuda kernels - self.cuda_modules = cuda_modules - self.jit_time_profiling = jit_time_profiling - - def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec): - runtime_args = [] - runtime_annotations = {} - runtime_defaults = [] - - # Calculate the offset where defaults start in the original args - if arg_spec.defaults: - defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) - else: - defaults_start_idx = len(arg_spec.args) - - # Filter arguments and maintain their properties - for i, arg_name in enumerate(arg_spec.args): - arg_type = arg_spec.annotations.get(arg_name, None) - - # Skip compile-time arguments - if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): - continue - - # Keep runtime arguments - runtime_args.append(arg_name) - if arg_name in arg_spec.annotations: - runtime_annotations[arg_name] = arg_type - - # Keep corresponding default if it exists - if i >= defaults_start_idx: - default_idx = i - defaults_start_idx - runtime_defaults.append(arg_spec.defaults[default_idx]) - - # Filter kwonlyargs and their defaults - runtime_kwonlyargs = [] - runtime_kwonlydefaults = {} - - if arg_spec.kwonlyargs: - for kwarg in arg_spec.kwonlyargs: - arg_type = arg_spec.annotations.get(kwarg, None) - - # Apply same filtering logic - if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): - continue - - runtime_kwonlyargs.append(kwarg) - if kwarg in arg_spec.annotations: - runtime_annotations[kwarg] = arg_type - if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: - runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg] - - # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec) - runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None - - return inspect.FullArgSpec( - args=runtime_args, - varargs=arg_spec.varargs, # Keep original varargs - varkw=arg_spec.varkw, # Keep original varkw - defaults=runtime_defaults, - kwonlyargs=runtime_kwonlyargs, - kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None, - annotations=runtime_annotations, - ) - - def __del__(self): - if self.cuda_modules: - cuda_modules = [module.cuda_module for module in self.cuda_modules.modules] - for module in set(cuda_modules): - cuda_helpers.unload_cubin_module(module) - - def get_constexpr_args(self) -> list[dict[str, int | str]]: - """ - This function returns the constexpr args that have been pruned from the original function signature. - The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). - - :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). - :rtype: list[dict[str, int | str]] - """ - if self.original_args_spec is None: - return list() - constexpr_args = list() - for i, arg_name in enumerate(self.original_args_spec.args): - if arg_name not in self.args_spec.args: - constexpr_args.append({"argument_index": i, "argument_name": arg_name}) - - if self.original_args_spec.kwonlyargs: - for kwarg in self.original_args_spec.kwonlyargs: - if kwarg not in self.args_spec.kwonlyargs: - constexpr_args.append( - {"argument_index": None, "argument_name": kwarg} - ) - return constexpr_args - - def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec): - """ - This function is the prune version of `generate_mlir_function_types` which only generates execution args - to get rid of mlir context. - """ - - # Process positional arguments with defaults - rectified_args = list(args) - if args_spec.defaults and len(args) < len(args_spec.args): - rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :]) - for k, v in kwargs.items(): - if k in args_spec.args: - idx = args_spec.args.index(k) - if idx < len(rectified_args): - rectified_args[idx] = v - else: - rectified_args.append(v) - - # Process keyword arguments - rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args} - if args_spec.kwonlydefaults and len(rectified_kwargs) < len( - args_spec.kwonlyargs - ): - rectified_kwargs.update(args_spec.kwonlydefaults) - - # args/kwargs must match arg_specs - if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len( - args_spec.kwonlyargs - ): - raise DSLRuntimeError( - "input args/kwargs length does not match runtime function signature!", - context={ - "input args length": len(rectified_args), - "input kwargs length": len(rectified_kwargs), - "function signature args length": len(args_spec.args), - "function signature kwonlyargs length": len(args_spec.kwonlyargs), - }, - ) - - exe_args = [] - adapted_args = [] - input_args = rectified_args + list(rectified_kwargs.values()) - input_arg_names = args_spec.args + args_spec.kwonlyargs - for arg, arg_name in zip(input_args, input_arg_names): - # short-cut for args already converted - if hasattr(arg, "__c_pointers__"): - exe_args.extend(arg.__c_pointers__()) - continue - - arg_type = args_spec.annotations.get(arg_name, None) - - # Implicit cast to NumericMeta - if isinstance(arg_type, t.NumericMeta): - arg = t.cast(arg, arg_type) - else: - # If not any known type, try registered adapter to do the conversion - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - if adapter: - arg = adapter(arg) - adapted_args.append(arg) - - exe_args.extend(get_c_pointers(arg)) - - return exe_args, adapted_args - - def __call__(self, *args, **kwargs): - exe_args, adapted_args = self.generate_execution_args( - args, kwargs, self.args_spec - ) - - self.run_compiled_program(exe_args) - - # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`. - def get_invoke_packed_args(self, exe_args): - if self.cuda_modules: - exe_args += self.cuda_modules.args - packed_args = (ctypes.c_void_p * len(exe_args))() - for argNum in range(len(exe_args)): - packed_args[argNum] = exe_args[argNum] - return packed_args - - def run_compiled_program(self, exe_args): - if self.jit_time_profiling: - profiler = timer(enable=True) - try: - packed_args = profiler(self.get_invoke_packed_args)(exe_args) - profiler(self.capi_func)(packed_args) - except Exception as e: - raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) - else: - try: - packed_args = self.get_invoke_packed_args(exe_args) - self.capi_func(packed_args) - except Exception as e: - raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) - - def update_jit_cuda_modules(self, kernel_symbols): - # preload cuda module from compiled cubin in ir and store to jit_executor.kernels. - if len(kernel_symbols) > 0: - extra_args = [] - module = self.ir_module - cuda_kernel_cache = dict() - cuda_driver_version = cuda_helpers.get_driver_version() - for sym in kernel_symbols: - if sym not in cuda_kernel_cache: - log().debug(f"Loading CUDA module for symbol: {sym}") - - # load cuda module/get function pointer from module and cache - def walk_callback(sym, func_sym, cubin_data): - cubin_module = cuda_helpers.load_cubin_module_data(cubin_data) - kernel_ptr = cuda_helpers.get_kernel_function( - cubin_module, func_sym - ) - # Enable non-portable cluster size for CUDA version 11.8 or higher. - if cuda_driver_version >= 11080: - cuda_helpers.set_kernel_attribute( - kernel_ptr, - cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, - 1, - ) - cuda_kernel_cache[sym] = CudaSingleModule( - cubin_module, kernel_ptr - ) - - self.walk_module_and_get_cubin_data(module, sym, walk_callback) - else: - log().debug(f"Symbol {sym} already in cache") - # check if kernel is empty. - if sym in cuda_kernel_cache: - extra_args.append( - ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr()) - ) - # store to the jit result if jit result is cached. - self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args) - - return self - - def _get_escaped_cubin_bytes(self, cubin_data): - """This function escapes cubin data from mlir raw bytecode to executable binary bytes""" - - def ishex(inp): - return ( - inp in range(0x30, 0x3A) - or inp in range(0x61, 0x67) - or inp in range(0x41, 0x47) - ) - - converted = bytearray() - idx = 0 - while idx < len(cubin_data): - # escape the original bytes - if cubin_data[idx] == 0x5C: - # if data of idx is b'\\' - if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]): - converted += bytearray.fromhex( - cubin_data[idx + 1 : idx + 3].decode() - ) - idx += 3 - elif cubin_data[idx + 1] == 0x5C: - converted.append(cubin_data[idx]) - idx += 2 - else: - # no escape, directly write - converted.append(cubin_data[idx]) - idx += 1 - return bytes(converted) - - def walk_module_and_get_cubin_data(self, module, sym, callback): - """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.""" - - def walk_gpu_binary_op(op): - if op.name != "gpu.binary": - return ir.WalkResult.ADVANCE - s = io.BytesIO() - op.write_bytecode(s) - cubin_data = s.getvalue() - if sym.encode() not in cubin_data: - return ir.WalkResult.ADVANCE - - if ( - "kernels" != op.opview.sym_name.value - and sym != op.opview.sym_name.value - ): - return ir.WalkResult.ADVANCE - # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir - func_sym = sym - if sym == op.opview.sym_name.value and not sym.endswith("_kernel"): - func_sym = sym.rsplit("_", 1)[0] - - cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0] - cubin_data = self._get_escaped_cubin_bytes(cubin_data) - callback(sym, func_sym, cubin_data) - return ir.WalkResult.ADVANCE - - module.operation.walk(walk_gpu_binary_op) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py deleted file mode 100644 index ccc475fdda59450f07c35ae244d6223446470c6d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides a runtime utility functions that are needed for -the DSL. -""" - -from . import dlpack_types -from . import cuda -from . import jit_arg_adapters - -__all__ = [ - "dlpack_types", - "cuda", - "jit_arg_adapters", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py deleted file mode 100644 index 97ae778c0cd5ae19d20fac8e045e2021832f5bbc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py +++ /dev/null @@ -1,476 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides CUDA Python helper functions -""" - - -from functools import lru_cache -from dataclasses import dataclass -from typing import List, Optional -import numpy as np -import os -import ctypes - -import cuda.bindings.driver as cuda -import cuda.bindings.nvrtc as nvrtc - -# MLIR imports -from ..._mlir import ir -from ..._mlir.dialects import gpu - -# Local module imports -from ..utils.logger import log as _log -from ..common import * -from .jit_arg_adapters import JitArgAdapterRegistry - - -# ============================================================================= -# Utils -# ============================================================================= - - -def _cudaGetErrorEnum(error): - if isinstance(error, cuda.CUresult): - err, name = cuda.cuGetErrorName(error) - return name if err == cuda.CUresult.CUDA_SUCCESS else "" - elif isinstance(error, nvrtc.nvrtcResult): - return nvrtc.nvrtcGetErrorString(error)[1] - else: - raise DSLRuntimeError("Unknown error type: {}".format(error)) - - -def _get_gpu_arch_info(major, minor): - """Get GPU architecture information and compatibility details.""" - gpu_arch_map = { - (7, 0): ("Volta", "sm_70", ["sm_70"]), # V100 - (7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX - (8, 0): ("Ampere", "sm_80", ["sm_80"]), # A100 - (8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]), # RTX 30 Series - (8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]), # RTX 40 Series - (8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40 - (9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100 - (10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200 - } - return gpu_arch_map.get( - (major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"]) - ) - - -def get_compute_capability_major_minor(device_id: int = 0): - """ - Returns the compute capability of the CUDA device as a tuple of (major, minor). - For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell. - Returns None on failure. - """ - try: - checkCudaErrors(cuda.cuInit(0)) - device = checkCudaErrors(cuda.cuDeviceGet(device_id)) - major = checkCudaErrors( - cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, - device, - ) - ) - minor = checkCudaErrors( - cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, - device, - ) - ) - return major, minor - except RuntimeError as e: - _log().info(f"Failed to get CUDA compute capability: {e}") - return None, None - - -@dataclass -class DeviceInfo: - """Data class to store CUDA device information.""" - - device_count: int = 0 - current_device: int = 0 - device_name: Optional[str] = None - major_version: Optional[int] = None - minor_version: Optional[int] = None - arch_name: Optional[str] = None - sm_arch: Optional[str] = None - compatible_archs: Optional[List[str]] = None - memory_gb: Optional[float] = None - target_arch: Optional[str] = None - error_message: Optional[str] = None - initialization_failed: bool = False - - def pretty_str(self) -> str: - """ - Convert DeviceInfo to a formatted string for display. - """ - info = "" - - if self.initialization_failed: - return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}" - - if self.error_message: - return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}" - - if self.device_count > 0: - info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n" - - if self.major_version is not None and self.minor_version is not None: - info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n" - info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n" - - if self.memory_gb is not None: - info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n" - - else: - info += f"- Compute capability: unknown\n" - info += f"- SM arch: unknown{Colors.RESET}\n" - else: - info += f"- No devices available\n" - - return info - - -def get_device_info() -> DeviceInfo: - """ - Get detailed information about CUDA devices. - Returns a DeviceInfo dataclass with device information. - """ - device_info = DeviceInfo() - - # Initialize CUDA if not already initialized - try: - result = cuda.cuInit(0) - if result[0].value: # Check for error - device_info.initialization_failed = True - return device_info - except: - pass - - try: - # Get device count - result = cuda.cuDeviceGetCount() - device_info.device_count = result[1] if result[0].value == 0 else 0 - - if device_info.device_count > 0: - # Get current device - try: - result = cuda.cuCtxGetDevice() - if result[0].value == 0: - device_info.current_device = result[1] - except: - pass - - # Get device name - try: - name_result = cuda.cuDeviceGetName(100, device_info.current_device) - if name_result[0].value == 0: - device_info.device_name = name_result[1] - except: - pass - - # Get compute capability and architecture info - try: - major, minor = get_compute_capability_major_minor( - device_info.current_device - ) - - # Check if we successfully got the compute capability - if major is not None and minor is not None: - device_info.major_version = major - device_info.minor_version = minor - - arch_name, sm_arch, compatible_archs = _get_gpu_arch_info( - device_info.major_version, device_info.minor_version - ) - - device_info.arch_name = arch_name - device_info.sm_arch = sm_arch - device_info.compatible_archs = compatible_archs - - # Get memory info - try: - total_mem = cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY, - device_info.current_device, - ) - if total_mem[0].value == 0: - device_info.memory_gb = total_mem[1] / ( - 1024 * 1024 * 1024 - ) # Convert to GB - except: - pass - - except Exception as e: - pass # Compute capability info will remain None - - except Exception as e: - device_info.error_message = str(e) - - return device_info - - -def checkCudaErrors(result): - """Check CUDA errors and provide detailed error messages.""" - if result[0].value: - error_code = result[0].value - error_name = _cudaGetErrorEnum(result[0]) - - raise DSLCudaRuntimeError(error_code, error_name) - - if len(result) == 1: - return None - elif len(result) == 2: - return result[1] - else: - return result[1:] - - -# ============================================================================= -# Driver Helpers -# ============================================================================= - - -@lru_cache(maxsize=1) -def initialize_cuda_context(device_id: int = 0, flags: int = 0): - """ - Initializes the CUDA context for a specified device. - """ - # Initialize CUDA Driver API - _log().info(f"cuInit {flags}") - checkCudaErrors(cuda.cuInit(flags)) - # Retrieve handle for device - _log().info(f"cuDeviceGet {device_id}") - cuDevice = checkCudaErrors(cuda.cuDeviceGet(device_id)) - _log().info(f"{cuDevice} <-- cuDeviceGet") - # Create context - _log().info(f"cuCtxCreate {0} {cuDevice}") - if cuda.CUDA_VERSION >= 13000: - # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2 - # and v3 API has been removed from CTK 13. - # See https://github.com/NVIDIA/cuda-python/pull/792 - context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice)) - else: - context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) - _log().info(f"{context} <-- cuCtxCreate") - - return context - - -def load_cubin_module(cubin_file): - """ - Loads a CUBIN file and returns the module. - """ - # Load CUBIN file as binary data - _log().info(f"read cubin {cubin_file}") - with open(cubin_file, "rb") as f: - cubin_data = f.read() - # Load module data - _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") - module = checkCudaErrors( - cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data) - ) - return module - - -def unload_cubin_module(module): - """ - Unloads a CUBIN module. - """ - _log().info(f"cuModuleUnload {module}") - checkCudaErrors(cuda.cuModuleUnload(module)) - - -def load_cubin_module_data(cubin_data): - """ - Loads a CUBIN from data and returns the module. - """ - # Load module data - _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") - module = checkCudaErrors( - cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data) - ) - return module - - -def get_kernel_function(module, kernel_name): - """ - Retrieves the kernel function from the module. - """ - _log().info(f"cuModuleGetFunction {module} {kernel_name}") - kernel = checkCudaErrors( - cuda.cuModuleGetFunction(module, bytes(kernel_name, "utf-8")) - ) - _log().info(f"{kernel} <-- cuModuleGetFunction") - return kernel - - -def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None): - """ - Launches the CUDA kernel. - """ - _log().info( - f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}" - ) - checkCudaErrors( - cuda.cuLaunchKernel( - kernel, - grid_dims[0], - grid_dims[1], - grid_dims[2], - block_dims[0], - block_dims[1], - block_dims[2], - smem_size, # Shared memory size - stream, - kernel_args, - 0, # Extra parameters - ) - ) - - -def stream_sync(stream): - """ - Synchronizes the CUDA stream. - """ - _log().info(f"cuStreamSynchronize {stream}") - checkCudaErrors(cuda.cuStreamSynchronize(stream)) - - -def stream_create(id=0): - """ - Creates the CUDA stream. - """ - _log().info(f"cuStreamCreate {id}") - stream = checkCudaErrors(cuda.cuStreamCreate(id)) - _log().info(f"{stream} <-- cuStreamCreate") - return stream - - -def stream_destroy(stream): - """ - Destroys the CUDA stream. - """ - _log().info(f"cuStreamDestroy {stream}") - checkCudaErrors(cuda.cuStreamDestroy(stream)) - - -def context_destroy(context): - """ - Destroys the CUDA context. - """ - _log().info(f"cuCtxDestroy {context}") - checkCudaErrors(cuda.cuCtxDestroy(context)) - - -def allocate(size_in_bytes: int, stream=None): - """ - Allocate device memory based on numpy host array size. - """ - _log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream) - if stream is None: - device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) - else: - device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream)) - _log().info("Allocated [%s]", device_memory) - return device_memory - - -def deallocate(device_pointer, stream=None): - """ - Deallocate the specified device memory pointer. - """ - _log().info( - "Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream - ) - if stream is None: - checkCudaErrors(cuda.cuMemFree(device_pointer)) - else: - checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream)) - - -def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): - """ - Copy data from host to device memory. - """ - _log().info( - "Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]", - hex(host_pointer), - hex(int(device_pointer)), - size_in_bytes, - stream, - ) - if stream is None: - checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes)) - else: - checkCudaErrors( - cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream) - ) - - -def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): - """ - Copy data from device to host memory. - """ - _log().info( - "Copy device-host-to device_pointer=[%s] host_pointer[%s] size_in_bytes=[%s] stream=[%s]", - hex(int(device_pointer)), - hex(host_pointer), - size_in_bytes, - stream, - ) - if stream is None: - checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes)) - else: - checkCudaErrors( - cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream) - ) - - -def default_stream(): - return cuda.CUstream(0) - - -def get_driver_version(): - """ - Returns the CUDA driver version. - """ - return checkCudaErrors(cuda.cuDriverGetVersion()) - - -def set_kernel_attribute(kernel, attribute, value): - """ - Sets a CUDA kernel attribute. - """ - return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value)) - - -@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream) -class StreamAdapter: - """ - Convert a CUDA stream to a stream representation for JIT arg generation. - """ - - def __init__(self, arg): - self._arg = arg - self._c_pointer = self._arg.getPtr() - - def __new_from_mlir_values__(self, values): - assert len(values) == 1 - return values[0] - - def __c_pointers__(self): - return [self._c_pointer] - - def __get_mlir_types__(self): - return [gpu.AsyncTokenType.get()] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py deleted file mode 100644 index 5addb275b12f2b18e109b0592a87f3044d2fe595..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import copy - -from . import cuda as cuda_helpers -from .tensor_descriptor import * -from ..common import * - - -def allocate(tensor: TensorDescriptor, stream=None): - """ - Allocates GPU memory - """ - if tensor._check_is_managed_by_framework(): - raise DSLRuntimeError( - "GPU tensors are managed by the framework and cannot be modified." - ) - if not tensor.device_pointer is None: - raise DSLRuntimeError("Tensor is already allocated on the device.") - - tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream) - - log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) - - -def deallocate(tensor: TensorDescriptor, stream=None): - """ - Deallocates GPU memory - """ - if tensor._check_is_managed_by_framework(): - raise DSLRuntimeError( - "GPU tensors are managed by the framework and cannot be modified." - ) - if tensor.device_pointer is None: - raise DSLRuntimeError("Tensor is not allocated on the device.") - - log().info( - "Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer - ) - - cuda_helpers.deallocate(tensor.device_pointer, stream) - tensor.device_pointer = None - - -def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None): - """ - Copies data from host memory to the GPU memory. - If do_allocate is True, it first calls allocate - """ - log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) - if do_allocate: - allocate(tensor, stream) - cuda_helpers.memcpy_h2d( - tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream - ) - log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) - return tensor - - -def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None): - """ - Copies data from GPU memory back to the host. - If do_deallocate is True, it calls deallocate - """ - log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) - if tensor._check_is_managed_by_framework(): - raise DSLRuntimeError( - "GPU tensors are managed by the framework and cannot be modified." - ) - if tensor.device_pointer is None: - raise DSLRuntimeError("Tensor is not allocated on the device.") - - cuda_helpers.memcpy_d2h( - tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream - ) - if do_deallocate: - deallocate(tensor, stream) - log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) - - -def to_gpu(tensor, stream=None) -> TensorDescriptor: - """ - Copies the tensor to the GPU memory from Host memory - """ - if isinstance(tensor, TensorDescriptor): - new_tensor = copy.copy(tensor) - copy_to_gpu(new_tensor, stream=stream) - return new_tensor - - if TensorDescriptor.can_transformed_to_dlpack(tensor): - new_tensor = TensorDescriptor(tensor) - copy_to_gpu(new_tensor, stream=stream) - return new_tensor - - raise DSLRuntimeError("Unsupported type") - - -def from_gpu(tensor, stream=None) -> TensorDescriptor: - """ - Copies the tensor to the GPU memory from Host memory - """ - if isinstance(tensor, TensorDescriptor): - new_tensor = copy.copy(tensor) - copy_from_gpu(new_tensor, stream=stream) - return new_tensor - - if TensorDescriptor.can_transformed_to_dlpack(tensor): - new_tensor = TensorDescriptor(tensor) - copy_from_gpu(new_tensor, stream=stream) - return new_tensor - - raise DSLRuntimeError("Unsupported type") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py deleted file mode 100644 index 168c2a9953f74b45cadfcbb6562f89d1bb35cd6d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides helper structs for dlpack. -DLPack is an open standard for in-memory tensor structures, enabling -seamless sharing of tensors across different frameworks. -Learn more at: https://github.com/dmlc/dlpack -""" - -import ctypes -import enum - - -class DLDeviceType(enum.IntEnum): - """Enums for device types based on the DLPack specification.""" - - kDLCPU = 1 - kDLGPU = 2 - kDLCPUPinned = 3 - - -class DLDataTypeCode: - """Enums for data type codes based on the DLPack specification. - - see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h - """ - - kDLInt = 0 - kDLUInt = 1 - kDLFloat = 2 - kDLOpaqueHandle = 3 - kDLBfloat = 4 - kDLComplex = 5 - kDLBool = 6 - - -class DLDevice(ctypes.Structure): - """Structure representing the device information in DLPack.""" - - _fields_ = [ - ("device_type", ctypes.c_int), # kDLCPU, kDLGPU, etc. - ("device_id", ctypes.c_int), # Device ID (e.g., GPU ID) - ] - - -class DLDataType(ctypes.Structure): - """Structure representing the data type in DLPack.""" - - _fields_ = [ - ("code", ctypes.c_uint8), # Data type code (e.g., kDLFloat) - ("bits", ctypes.c_uint8), # Number of bits per value - ("lanes", ctypes.c_uint16), # Number of lanes - ] - - -class DLTensor(ctypes.Structure): - """Structure representing the DLTensor in DLPack.""" - - _fields_ = [ - ("data", ctypes.c_void_p), # Pointer to tensor data - ("device", DLDevice), # Device info - ("ndim", ctypes.c_int), # Number of dimensions - ("dtype", DLDataType), # Data type - ("shape", ctypes.POINTER(ctypes.c_int64)), # Shape of tensor - ("strides", ctypes.POINTER(ctypes.c_int64)), # Strides of tensor - ("byte_offset", ctypes.c_uint64), # Byte offset to tensor data - ] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py deleted file mode 100644 index eb998d16d8fb4bcf592f17ce0f23a81d6e11bff6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides runtime utilities for JIT argument conversion in DSL. -""" - -from functools import wraps -from typing import get_origin - -# Local modules imports -from ..common import DSLRuntimeError -from ..typing import ( - Constexpr, - Int32, - Float32, - Boolean, -) - - -def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func): - """ - Check if the argument spec is a constexpr. - """ - - def _is_reserved_python_func_arg(arg_index, arg_name, func): - """ - Check if the argument is a reserved python function argument. - """ - - if arg_index != 0: - return False - - if arg_name == "self": - return True - - is_classmethod = isinstance(func, classmethod) or ( - hasattr(func, "__func__") and isinstance(func.__func__, classmethod) - ) - return arg_name == "cls" and is_classmethod - - return ( - _is_reserved_python_func_arg(arg_index, arg_name, owning_func) - or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr)) - or (get_origin(arg_spec) is Constexpr) - ) - - -def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func): - """ - Check if the argument is a constexpr. - """ - - def _is_type_argument(arg, arg_annotation): - """ - Check if the argument is a type argument like Type[X] - """ - - return isinstance(arg, type) and ( - arg_annotation is None or get_origin(arg_annotation) is type - ) - - return ( - is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func) - or _is_type_argument(arg, arg_spec) - or arg is None - ) - - -class JitArgAdapterRegistry: - """ - A registry to keep track of the JIT argument adapters. - - An adapter is a callable that converts a Python type to a type with following protocols supported: - - JitArgument - - DynamicExpression - The converted type can then be further processed by DSL to generate arguments for JIT functions. - """ - - # A dictionary with key=type and value=callable - jit_arg_adapter_registry = {} - - @classmethod - def register_jit_arg_adapter(cls, *dargs, **dkwargs): - """ - Register a JIT argument adapter callable - - This can be used as a decorator on any callable like: - - @register_jit_arg_adapter(my_py_type) - def my_adapter_for_my_py_type(arg): - ... - - @register_jit_arg_adapter(my_py_type) - class MyAdapterForMyPythonType: - ... - - The adapters are registered per type. If a type is already registerd, an error will be raised. - """ - - def decorator(*dargs, **dkwargs): - darg_python_ty = dargs[0] - - @wraps(darg_python_ty) - def wrapper(*args, **kwargs): - if len(args) != 1 or not callable(args[0]): - raise DSLRuntimeError( - "a callable must be provided for registering JIT argument adapter" - ) - adapter = args[0] - - if darg_python_ty in cls.jit_arg_adapter_registry: - raise DSLRuntimeError( - f"JIT argument adapter for {darg_python_ty} is already registered!", - context={ - "Registered adapter": cls.jit_arg_adapter_registry[ - darg_python_ty - ], - "Adapter to be registered": adapter, - }, - ) - cls.jit_arg_adapter_registry[darg_python_ty] = adapter - return adapter - - return wrapper - - if len(dargs) > 0: - return decorator(*dargs, **dkwargs) - else: - raise DSLRuntimeError( - "a Python type must be provided for registering JIT argument adapter" - ) - - @classmethod - def get_registered_adapter(cls, ty): - """ - Get the registered JIT argument adapter for the given type. - """ - return cls.jit_arg_adapter_registry.get(ty, None) - - -# ============================================================================= -# JIT Argument Adapters -# ============================================================================= - - -@JitArgAdapterRegistry.register_jit_arg_adapter(int) -@JitArgAdapterRegistry.register_jit_arg_adapter(float) -@JitArgAdapterRegistry.register_jit_arg_adapter(bool) -def _convert_python_scalar(arg): - """ - Convert a Python scalar to a DSL type. - """ - conversion_map = { - int: Int32, - float: Float32, - bool: Boolean, - } - return conversion_map.get(type(arg))(arg) - - -@JitArgAdapterRegistry.register_jit_arg_adapter(tuple) -@JitArgAdapterRegistry.register_jit_arg_adapter(list) -def _convert_python_sequence(arg): - """ - Go through each element in the sequence and convert it to a type that can be - further processed by DSL to generate the corresponding JIT argument(s). - """ - adapted_arg = [] - for elem in arg: - adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem)) - if adapter is not None: - converted_elem = adapter(elem) - adapted_arg.append(converted_elem) - else: - # If no registered adapter is found, just return the original element - adapted_arg.append(elem) - - assert len(adapted_arg) == len(arg) - return type(arg)(adapted_arg) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py deleted file mode 100644 index 1a992ef68293d6f969ab551b6321c3696c961037..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py +++ /dev/null @@ -1,201 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -# Helpers -import itertools, operator -import ctypes -from . import dlpack_types as _dpack -from .dlpack_runtime import ( - dlpack_to_tensor_desc, - get_tensor_desc_data_ptr, - get_tensor_desc_is_in_device, - get_tensor_desc_element_type, - get_tensor_desc_shape, - get_tensor_desc_stride, - get_tensor_desc_element_size_in_bytes, - get_tensor_desc_ndim, - get_tensor_desc_dtype_code, - get_tensor_desc_dtype_bits, - get_tensor_desc_device_type, - get_tensor_desc_device_id, -) - -from ..utils.logger import log -from ..common import * -from ..typing import ( - Boolean, - Float8E5M2, - Int64, - Int32, - Int16, - Int8, - Uint64, - Uint32, - Uint16, - Uint8, - Float64, - Float32, - Float16, - BFloat16, -) - - -class TensorDescriptor: - def __init__(self, tensor): - """Initialize with a tensor that supports the DLPack protocol. - - Args: - tensor: Any tensor object that implements __dlpack__ and __dlpack_device__ - """ - - self.tensor = tensor - self._capsule = dlpack_to_tensor_desc(tensor) - - self.data_ptr = get_tensor_desc_data_ptr(self._capsule) - self.device_type = get_tensor_desc_device_type(self._capsule) - self.device_type = _dpack.DLDeviceType(self.device_type) - - if self.device_type == _dpack.DLDeviceType.kDLGPU: - self.device_pointer = self.data_ptr - elif self.device_type == _dpack.DLDeviceType.kDLCPU: - self.device_pointer = None - else: - raise DSLRuntimeError( - f"DLPack device type is not supported {self.dl_tensor.device.device_type}" - ) - - log().info("TensorDescriptor is created = [%s]", self) - - @staticmethod - def can_transformed_to_dlpack(dl_tensor): - if not hasattr(dl_tensor, "__dlpack__") or not hasattr( - dl_tensor, "__dlpack_device__" - ): - return False - return True - - @property - def is_in_device(self): - """Check if the tensor is stored on a device.""" - return not self.device_pointer is None - - @property - def device_id(self): - """Return device id where tensor resides.""" - if self.is_in_device: - return get_tensor_desc_device_id(self._capsule) - return -1 - - @property - def element_type(self): - """Return the corresponding Python type based on DLPack dtype metadata.""" - str_element_type = get_tensor_desc_element_type(self._capsule) - dtype_map = { - # bool is 8bit from numpy and torch - "Bool": Boolean, - "Int64": Int64, - "Int32": Int32, - "Int16": Int16, - "Int8": Int8, - "UInt64": Uint64, - "UInt32": Uint32, - "UInt16": Uint16, - "UInt8": Uint8, - "Float64": Float64, - "Float32": Float32, - "Float16": Float16, - "BFloat16": BFloat16, - "Float8E5M2": Float8E5M2, - } - - if str_element_type not in dtype_map: - raise KeyError( - f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}" - ) - - return dtype_map[str_element_type] - - @property - def shape(self): - """Return the shape of the tensor.""" - return get_tensor_desc_shape(self._capsule) - - @property - def rank(self): - """Return the rank of the tensor.""" - return get_tensor_desc_ndim(self._capsule) - - @property - def strides(self): - """Return the rank of the tensor.""" - return get_tensor_desc_stride(self._capsule) - - @property - def element_size_in_bytes(self): - """Calculate the element size in bytes of the DLPack tensor.""" - return get_tensor_desc_element_size_in_bytes(self._capsule) - - @property - def size_in_bytes(self): - """Calculate the total size in bytes of the DLPack tensor.""" - # Calculate the number of elements using the shape - ndim = get_tensor_desc_ndim(self._capsule) - shape = get_tensor_desc_shape(self._capsule) - num_elements = 1 - for i in range(ndim): - num_elements *= shape[i] - - # Total bytes - total_bytes = self.element_size_in_bytes * num_elements - return total_bytes - - def __str__(self): - """Return a compact string representation of the device_tensor with a tensor prefix.""" - # Extract shape - shape = "x".join(map(str, self.shape)) - - # Extract dtype - dtype_code = get_tensor_desc_dtype_code(self._capsule) - dtype_bits = get_tensor_desc_dtype_bits(self._capsule) - dtype = ( - f"i{dtype_bits}" - if dtype_code == _dpack.DLDataTypeCode.kDLInt - else f"f{dtype_bits}" - ) - - # Extract device - device_type = "cpu" if not self.is_in_device else "gpu" - - return f"tensor<{shape}x{dtype}>_{device_type}" - - def _check_is_managed_by_framework(self): - """ - Ensure the tensor is not managed by the framework (e.g., GPU tensor). - Raises an exception if the tensor is framework-managed. - """ - return self.device_type == _dpack.DLDeviceType.kDLGPU - - @staticmethod - def is_compatible(maybe_tensor_descriptor) -> bool: - """Check if the object is a TensorDescriptor or can be converted to one.""" - return isinstance( - maybe_tensor_descriptor, TensorDescriptor - ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor) - - -def from_tensor(tensor) -> TensorDescriptor: - """Create a TensorDescriptor from a tensor object.""" - return TensorDescriptor(tensor) - - -def to_tensor(tensor_descriptor: TensorDescriptor): - """Return tensor object from tensor descriptor.""" - return tensor_descriptor.tensor diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py deleted file mode 100644 index b46cff6de8176217f38af05b8604716c34aae009..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py +++ /dev/null @@ -1,1962 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import ctypes -import numpy as np -import operator -from typing_extensions import deprecated -from functools import reduce -from typing import ( - Generic, - Protocol, - Union, - Any, - List, - Type, - TypeVar, - overload, - runtime_checkable, - get_origin, -) -from types import FunctionType -from dataclasses import dataclass -from abc import ABC, abstractmethod - -from .common import * -from .ast_helpers import const_expr -from ._mlir_helpers import arith as arith_helper, lru_cache_ir -from ._mlir_helpers.arith import ArithValue - -from .._mlir import ir -from .._mlir.extras import types as T -from .._mlir.dialects import arith, math - -# ============================================================================= -# Dynamic Expression Protocol -# ============================================================================= - - -@runtime_checkable -class DynamicExpression(Protocol): - """Protocol defining the interface for object holding dynamic values in the DSL. - - This protocol enables classes to represent dynamic values in the DSL. Classes implementing - this protocol can be used in JIT-compiled functions and dynamic value generation. - - It is required for custom data types to work correctly with following JIT features: - * as function argument to call another JIT function from JIT function - * as return value from JIT function - * for constructions like if-else, while-loop, etc. - - :param value: The MLIR operation result value to initialize the object with - :type value: ir.Value - - **Required Methods** - - * ``__extract_mlir_values__``: Extract MLIR values from the object - * ``__new_from_mlir_values__``: Create new instance from MLIR values - - **Implementation Example** - - To implement a custom data type that works with the DSL: - - .. code-block:: python - - class CustomData(metaclass=DslType): - def __init__(self, int_value): - self.int_value = int_value - - def __extract_mlir_values__(self): - return [self.int_value] - - def __new_from_mlir_values__(self, values): - return CustomData(values[0]) - - **Usage in JIT Functions** - - When used in JIT-compiled functions, the DSL automatically extracts MLIR values: - - .. code-block:: python - - @jit - def caller(): - x = CustomData(1) - return foo(x) - - This generates MLIR like: - - .. code-block:: mlir - - func @caller() -> i32 { - %0 = func.call @foo(%arg0) : (i32) -> i32 - return %0 : i32 - } - """ - - def __extract_mlir_values__(self): - """Extract MLIR values from this object. - - :return: List of MLIR values representing this object's data - :rtype: List[ir.Value] - """ - raise NotImplementedError - - def __new_from_mlir_values__(self, values): - """Create a new instance from MLIR values. - - :param values: List of MLIR values to construct the object from - :type values: List[ir.Value] - :return: New instance of the implementing class - :rtype: Any - """ - raise NotImplementedError - - -@runtime_checkable -class JitArgument(Protocol): - """ - Protocol class defining the interface for JIT function argument generation. - - This protocol enables classes to provide the necessary information for generating - JIT function arguments and allow the DSL JIT executor to call JIT compiled functions. - - **Required Methods** - - * ``__c_pointers__``: Returns ctypes pointers for runtime execution - * ``__get_mlir_types__``: Returns MLIR types for function definition - * ``__new_from_mlir_values__``: Creates new instances from MLIR values - - **Example** - - .. code-block:: python - - class CustomData: - def __init__(self, int_value, ...): - self.int_value = int_value - ... - - def __c_pointers__(self): - return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...] - - def __get_mlir_types__(self): - return [ir.IntegerType.get(32), ...] - - def __new_from_mlir_values__(self, values): - return CustomData(values[0], ...) - - @jit - def foo(x: CustomData): - a = x.int_value + 1 - ... - - # `CustomData` is an argument of `foo` - foo(CustomData(1, ...)) - - When called like ``y = foo(x)``, the following steps occur: - - 1. JIT compiler generates MLIR function definition using ``__get_mlir_types__`` - - .. code-block:: mlir - - func.func @foo(%arg0: i32, ...) { - ... - - return - } - - 2. JIT function can't use values from Python, so it needs to reconstruct the object from - MLIR values, a.k.a `%arg0`, with ``__new_from_mlir_values__`` and pass it to `foo`. - - Following code demonstrates how JIT compiler reconstructs the object and pass to Python. - - .. code-block:: python - - # Implementation of IR tracing - new_x = CustomData(ir.Value(%arg0), ...) - y = foo(new_x) - # `x.int_value` is %arg0 rather than `c1` defined by Python. - - 3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__`` - pointing to the underlying data object passing to JIT compiled function. - - .. code-block:: python - - jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...])) - """ - - def __c_pointers__(self): - """ - Generate a list of ctypes pointers for the current object. - - :return: List of ctypes pointers - :rtype: List[ctypes.c_void_p] - """ - raise NotImplementedError - - def __get_mlir_types__(self): - """ - Generate a list of MLIR types for the current object. - - :return: List of MLIR types - :rtype: List[ir.Type] - """ - raise NotImplementedError - - def __new_from_mlir_values__(self, values): - """ - Create a new object from MLIR values. - - :param values: List of MLIR values - :type values: List[ir.Value] - :return: A new object that represents the given MLIR values - :rtype: Any - """ - raise NotImplementedError - - -def get_c_pointers(obj): - """ - Given the `obj`, recursively go through it to extract all contained C pointers - """ - if hasattr(obj, "__c_pointers__"): - return obj.__c_pointers__() - elif isinstance(obj, (tuple, list)): - return sum((get_c_pointers(x) for x in obj), []) - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in get_c_pointers to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - return [] - - -def get_mlir_types(obj): - """ - Given the `obj`, recursively go through it to extract all contained MLIR types - """ - if hasattr(obj, "__get_mlir_types__"): - return obj.__get_mlir_types__() - elif hasattr(obj, "__extract_mlir_values__"): - return [v.type for v in obj.__extract_mlir_values__()] - elif isinstance(obj, ir.Value): - return [obj.type] - elif isinstance(obj, (tuple, list)): - return sum((get_mlir_types(x) for x in obj), []) - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in get_mlir_types to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - return [] - - -class DslType(type): - """Metaclass for all DSL types in the system. - - This metaclass provides type system infrastructure for DSL types, handling MLIR - type mappings and NumPy type conversions. - - All data types in DSL must provide the following methods: - - :param mlir_type: Corresponding MLIR type for this DSL type - :type mlir_type: Any, optional - :param is_abstract: Whether this type is abstract, defaults to False - :type is_abstract: bool, optional - - **Required Methods** - - * ``__str__`` (classmethod): Return string representation of the type - * ``__c_pointers__`` (optional): Return list of ctypes pointers of data used to invoke JIT function - * ``__get_mlir_types__``: Return list of MLIR types of the MLIR values contained in the instance - * ``__extract_mlir_values__``: Return list of MLIR values contained in the instance - * ``__new_from_mlir_values__``: Return a new instance from list of MLIR values - - **Attributes** - - :ivar _ir: MLIR provider - :vartype _ir: Any - :ivar _T: MLIR Type system provider - :vartype _T: Any - - **Properties** - - :property mlir_type: Returns the corresponding MLIR type for this DSL type - :type mlir_type: Any - - """ - - _is_abstract: bool - - def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs): - new_cls = super().__new__(cls, name, bases, attrs) - - new_cls._is_abstract = is_abstract - - return new_cls - - @property - def is_abstract(cls): - return cls._is_abstract - - -class NumericMeta(DslType): - """Metaclass for numeric types providing width and numpy dtype information. - - :param width: Bit width of the numeric type, defaults to 8 - :type width: int - :param np_dtype: Corresponding NumPy dtype - :type np_dtype: numpy.dtype, optional - :param mlir_type: Corresponding MLIR type - :type mlir_type: Any, optional - :param is_abstract: Whether the type is abstract, defaults to False - :type is_abstract: bool, optional - - :ivar width: Bit width of the numeric type - :type width: int - :ivar _np_dtype: Corresponding NumPy dtype - :type _np_dtype: Union[numpy.dtype, None] - - :property numpy_dtype: Returns the corresponding NumPy dtype - :rtype numpy_dtype: numpy.dtype - """ - - width: int - - # Placeholder type - _mlir_type = Any - _np_dtype: Union[np.dtype, None] - - def __new__( - cls, - name, - bases, - attrs, - width=8, - np_dtype=None, - mlir_type=None, - is_abstract=False, - **kwargs, - ): - def _extract_mlir_values(self): - return [self.ir_value()] - - def _new_from_mlir_values(self, values: list) -> "Numeric": - res_ty = type(self) - return res_ty(values[0]) - - new_attrs = { - "__extract_mlir_values__": _extract_mlir_values, - "__new_from_mlir_values__": _new_from_mlir_values, - } - new_cls = super().__new__( - cls, - name, - bases, - new_attrs | attrs, - is_abstract=is_abstract, - **kwargs, - ) - - if mlir_type is not None: - new_cls._mlir_type = staticmethod(mlir_type) - - new_cls.width = width - new_cls._np_dtype = np_dtype - return new_cls - - @property - def numpy_dtype(cls): - return cls._np_dtype - - @property - def is_integer(cls) -> bool: ... - - @property - def is_float(cls) -> bool: ... - - def is_same_kind(cls, other: Type) -> bool: - return cls.is_integer == other.is_integer or cls.is_float == other.is_float - - @staticmethod - def from_python(value: Any) -> Type["Numeric"]: - """ - Deduce the DSL type from a Python value. - """ - if isinstance(value, int): - return Int32 - elif isinstance(value, float): - return Float32 - elif isinstance(value, bool): - return Boolean - raise DSLRuntimeError( - f"Could not deduce Type[Numeric] from python value: {value} :{type(value)}" - ) - - @property - def mlir_type(cls): - return cls._mlir_type() # type: ignore - - -Value = TypeVar("Value") - - -def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric": - """Cast an object to the specified numeric type. - - :param obj: Object to be cast - :type obj: Union[bool, int, float, Value] - :param type_: Target numeric type - :type type_: Type[Numeric] - :raises TypeError: If casting to an abstract type or unsupported type conversion - :return: Object cast to the target numeric type - :rtype: Numeric - - Example:: - >>> x = cast(5, Int32) # Cast integer to Int32 - >>> y = cast(3.14, Float32) # Cast float to Float32 - """ - if type_.is_abstract: - if not isinstance(obj, type_): - raise TypeError( - f"can't cast {obj} to {type_}. Pass in concrete type instead, " - "e.g. Int32, Float32, etc." - ) - # If target_type is abstract, and value is instance of target_type, - # then we can return value as is - else: - # Implicit cast based on using annotation type - obj = type_(obj) - return obj - - -# Option 1: use ir.Value as base -# class IntegerMeta(DslType, type(ir.Value)): -class IntegerMeta(NumericMeta): - """Metaclass for integer types providing signedness information. - - :param width: Bit width of the integer type, defaults to 32 - :type width: int - :param signed: Whether the integer type is signed, defaults to True - :type signed: bool - :param mlir_type: Corresponding MLIR type, defaults to None - :type mlir_type: Any, optional - - :ivar signed: Whether the integer type is signed - :vartype signed: bool - :ivar arith: Arithmetic operations interface - :vartype arith: Any - """ - - signed: bool - - def __new__( - cls, - name, - bases, - attrs, - width=32, - signed=True, - mlir_type=None, - is_abstract=False, - ): - if width == 1: - np_dtype = np.bool_ - elif width == 128: - np_dtype = None - elif signed: - np_dtype = getattr(np, f"int{width}") - else: - np_dtype = getattr(np, f"uint{width}") - - def _c_pointers(self): - if width == 1: - c_value = ctypes.c_bool(self.value) - elif signed: - c_value = getattr(ctypes, f"c_int{width}")(self.value) - else: - c_value = getattr(ctypes, f"c_uint{width}")(self.value) - - return [ctypes.cast(ctypes.pointer(c_value), ctypes.c_void_p)] - - new_attrs = { - "__c_pointers__": _c_pointers, - } - new_cls = super().__new__( - cls, name, bases, attrs | new_attrs, width, np_dtype, mlir_type, is_abstract - ) - new_cls.signed = signed - return new_cls - - def __str__(cls): - return f"{cls.__name__}" - - @property - def is_integer(cls) -> bool: - return True - - @property - def is_float(cls) -> bool: - return False - - @property - def zero(cls) -> int: - return 0 - - @property - def min(cls) -> int: - if cls.signed: - return -(2 ** (cls.width - 1)) - else: - return 0 - - @property - def max(cls) -> int: - if cls.signed: - return 2 ** (cls.width - 1) - 1 - else: - return 2**cls.width - 1 - - def recast_width(cls, width): - type_map = { - 8: Int8, - 16: Int16, - 32: Int32, - 64: Int64, - 128: Int128, - } - if width not in type_map: - raise TypeError(f"Unsupported width: {width}") - return type_map[width] - - -class FloatMeta(NumericMeta): - """Metaclass for floating-point types. - - This metaclass provides type system infrastructure for floating-point types in the DSL, - handling MLIR type mappings and NumPy type conversions. - - :param width: Bit width of the float type, defaults to 32 - :type width: int - :param mlir_type: Corresponding MLIR type, defaults to None - :type mlir_type: Any, optional - :param is_abstract: Whether this is an abstract base class, defaults to False - :type is_abstract: bool, optional - - :ivar _arith: Arithmetic operations interface - :vartype _arith: Any - """ - - _exponent_width: int - _mantissa_width: int - - def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False): - np_dtype = getattr(np, name.lower(), None) - new_cls = super().__new__( - cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract - ) - # Extract exponent and mantissa bits from class name if it follows Float pattern - # For example: Float8E4M3 -> exponent_width=4, mantissa_width=3 - import re - - if not is_abstract: - match = re.match(r"Float(\d+)E(\d+)M(\d+)(?:.*)", name) - if match: - exp_bits = int(match.group(2)) - mant_bits = int(match.group(3)) - - # Store extracted values as class attributes - new_cls._exponent_width = exp_bits - new_cls._mantissa_width = mant_bits - # Don't have 1-to-1 mapping of narrow precision types like bfloat16, tfloat32, etc. - return new_cls - - def __str__(cls): - return f"{cls.__name__}" - - @property - def is_integer(cls) -> bool: - return False - - @property - def is_float(cls) -> bool: - return True - - @property - def zero(cls) -> float: - return 0.0 - - @property - def inf(cls) -> float: - return float("inf") - - @property - def nan(cls) -> float: - return float("nan") - - @property - def exponent_width(cls) -> int: - return cls._exponent_width - - @property - def mantissa_width(cls) -> int: - return cls._mantissa_width - - def recast_width(cls, width): - type_map = { - 16: Float16, - 32: Float32, - 64: Float64, - } - if width not in type_map: - raise TypeError(f"Unsupported width: {width}") - return type_map[width] - - -def _arith_signless_to_int(a, target_type): - # is_signed: sign of result type - if target_type.width > a.type.width: - # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL - if target_type.signed and a.type.width > 1: - return arith.extsi(target_type.mlir_type, a) - else: - return arith.extui(target_type.mlir_type, a) - elif target_type.width < a.type.width: - return arith.trunci(target_type.mlir_type, a) - else: - return a - - -def _binary_op_type_promote(a, b, promote_bool: bool = False): - """Promote two numeric operands following type promotion rules. - - :param a: First numeric operand - :type a: Numeric - :param b: Second numeric operand - :type b: Numeric - :param promote_bool: Whether to promote boolean types to Int32 for arithmetic operations, defaults to False - :type promote_bool: bool, optional - :raises ValueError: If implicit float promotion is not supported between the given types - :return: Tuple containing promoted operands and their resulting type - :rtype: tuple[Numeric, Numeric, Type[Numeric]] - - Type promotion rules: - 1. If operands are same type and not bools needing promotion: - - No promotion needed, return original types - 2. If either operand is float: - a. If one is float and one is int: - - Convert int to the float type - b. If both are float: - - Promote to higher precision float if width >= 16 - - For same width, promote to more general type (Float32 over TFloat32) - - Otherwise raise ValueError for unsupported promotion - 3. Otherwise, both operands are integers. Integer promotion rules: - a. If promote_bool is True and either operand is bool: - - Promote bool to Int32 for arithmetic operations - - Exceptions for numpy dtype casting: - - array(dtype=np.bool_) + array(dtype=np.bool_) -> array(dtype=np.bool_) - - What is not supported: - - promotion with narrow precision float types which requires explicit cast by user - """ - a_type = a.dtype - b_type = b.dtype - - # Early return for same types (except when they're bools that need promotion) - if a_type == b_type and not (promote_bool and a_type is Boolean): - return a, b, a_type - - # Handle floating point promotions - if a_type.is_float or b_type.is_float: - # Get highest precision float type based on bitwidth - a_width = getattr(a_type, "width", 0) - b_width = getattr(b_type, "width", 0) - - # If one type is integer, convert it to the float type - if a_type.is_float and not b_type.is_float: - b_type = a_type.recast_width(max(a_width, b_width)) - elif b_type.is_float and not a_type.is_float: - a_type = b_type.recast_width(max(a_width, b_width)) - - # Both are float types - handle precision promotion - if a_width > b_width and a_width >= 16: - res_type = a_type - elif b_width > a_width and b_width >= 16: - res_type = b_type - elif a_width == b_width: - # Same bitwidth - handle special cases like TFloat32 -> Float32 and BFloat16 -> Float16 - if a_type is Float64 or b_type is Float64: - res_type = Float64 - elif a_type is Float32 or b_type is Float32: - res_type = Float32 - elif a_type is Float16 or b_type is Float16: - res_type = Float16 - else: - raise ValueError( - f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly" - ) - else: - raise ValueError( - f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly" - ) - - # Only convert if type is different - new_a = a.to(res_type) if a.dtype != res_type else a - new_b = b.to(res_type) if b.dtype != res_type else b - return new_a, new_b, res_type - - # Handle bool promotion for arithmetic operations - if promote_bool: - if a_type is Boolean and b_type is Boolean: - # Only promote to Int32 when both are bool - a = a.to(Int32) - b = b.to(Int32) - a_type = b_type = a.dtype - - # If both were bools, they're now same type (Int32) - if a_type == b_type: - return a, b, a_type - - # Same type, no promotion needed - if a_type == b_type: - return a, b, a_type - - a_signed = a_type.signed - b_signed = b_type.signed - a_width = a_type.width - b_width = b_type.width - - # Mixed signedness case - if a_signed != b_signed: - unsigned_type = a_type if not a_signed else b_type - signed_type = a_type if a_signed else b_type - unsigned_width = a_width if not a_signed else b_width - - if unsigned_width >= signed_type.width: - # Promote both to unsigned of larger width - res_type = unsigned_type - else: - # Promote both to signed of larger width - res_type = signed_type - - new_a = a.to(res_type) if a.dtype != res_type else a - new_b = b.to(res_type) if b.dtype != res_type else b - return new_a, new_b, res_type - - # Same signedness, different width - promote to larger width - if a_width >= b_width: - return a, b.to(a.dtype), a.dtype - else: - return a.to(b.dtype), b, b.dtype - - -def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): - """Wrapper for binary operations on Numeric types. - - This wrapper handles type promotion, operation execution, and result type determination - for binary operations between Numeric types. - - :param op: The binary operation to perform (e.g., operator.add, operator.sub) - :type op: callable - :param emitter: Function that emits the MLIR operation for dynamic values - :type emitter: callable - :param promote_operand: Whether to promote operands to the same type, defaults to True - :type promote_operand: bool, optional - :param promote_bool: Whether to promote boolean results to Boolean type, defaults to False - :type promote_bool: bool, optional - :param flip: Whether to flip the operands when calling the operation, defaults to False - :type flip: bool, optional - - :raises TypeError: When an unsupported operation is attempted on specific numeric types - - .. note:: - Not all operations are supported for all numeric types. In particular: - - - Subtraction is not fully supported for Integer types - - Multiplication, floor division, and modulo operations may have limited support - - Division (truediv) with integer types is not fully supported and converts to Float32 - """ - - def wrapper(lhs, rhs, *, loc=None, ip=None): - orig_lhs_type = type(lhs) - orig_rhs_type = type(rhs) - - # When called directly with self and other - ty = type(lhs) - # Canonicalize to Numeric type for promotion - if not isinstance(rhs, Numeric): - if not isinstance(rhs, (ArithValue, int, float, bool)): - # This allows rhs class to implement __rmul__ - return NotImplemented - - if isinstance(rhs, ArithValue): - if isinstance(rhs.type, ir.VectorType): - return NotImplemented - - rhs = as_numeric(rhs) - - # default result type to left-hand-side - res_type = ty - - if promote_operand: - lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool) - else: - rhs = ty(rhs) - - if op in ( - operator.lt, - operator.le, - operator.gt, - operator.ge, - operator.eq, - operator.ne, - ): - res_type = Boolean - elif op == operator.truediv and isinstance(lhs, Integer): - res_type = Float32 - elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean: - res_type = Boolean - - if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer): - lhs_val = lhs.value.with_signedness(lhs.signed) - else: - lhs_val = lhs.value - - if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer): - rhs_val = rhs.value.with_signedness(rhs.signed) - else: - rhs_val = rhs.value - - if flip: - lhs_val, rhs_val = rhs_val, lhs_val - - # Check if the operation is supported by the operands - res_val = op(lhs_val, rhs_val) - return res_type(res_val, loc=loc, ip=ip) - - return wrapper - - -class Numeric(metaclass=NumericMeta, is_abstract=True): - """Base class for all numeric types in the DSL. - - This class provides the foundation for both Integer and Float types, - implementing basic arithmetic operations. - - :param value: The value to store in the numeric type - :type value: Union[bool, int, float, Value] - - :ivar value: The stored numeric value - :vartype value: Union[bool, int, float, Value] - """ - - def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None): - self.value = value - - def __str__(self) -> str: - # Use member's pretty-str method if member object has method. - # This can be extended in future to have better support for IDE, jupyter notebook, etc. - pretty_str = getattr(self.value, "pretty_str", None) - if pretty_str is not None: - return pretty_str() - else: - return "?" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({repr(self.value)})" - - def __hash__(self): - return hash(type(self).__class__) ^ hash(self.value) - - @property - def dtype(self) -> Type["Numeric"]: - return type(self) - - @overload - def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ... - - @overload - def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ... - - @overload - def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ... - - @overload - def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ... - - @overload - def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ... - - def to(self, dtype: Type, *, loc=None, ip=None): - """Convert this numeric value to another numeric type. - - If the target type is the same as the current type, returns self. - Otherwise, creates a new instance of the target type with the same value. - - :param dtype: The target numeric type to convert to - :type dtype: Union[Type["Numeric"], Type[int], Type[float], Type[bool]] - :return: A new instance of the target type, or self if types match - :rtype: Numeric - :raises TypeError: If trying to convert an MLIR value to a static Python type - :raises TypeError: If trying to convert to unsupported float types like Float8E4M3, - Float8E4M3B11FNUZ, Float4E2M1FN, Float6E3M2FN, or Float6E2M3FN - - .. note:: - - Unsupported destination float types: - - Float8E4M3 - - Float8E4M3B11FNUZ - - Float4E2M1FN - - Float6E3M2FN - - Float6E2M3FN - - Example:: - - .. code-block:: python - - # Convert between DSL numeric types - x = Int32(5) - y = x.to(Float32) # Converts to Float32(5.0) - - # Convert to Python primitive types - # They are considered as static values at JIT time - z = x.to(int) # Returns Python int 5 - w = y.to(float) # Returns Python float 5.0 - - # This will raise a ValueError - mlir_val = arith.constant(T.i32(), 42) - num = Int32(mlir_val) - num.to(int) # ValueError: unable to convert MLIR value to static type: - """ - if dtype in _unsupported_dst_float_types: - raise TypeError(f"Unsupported destination float type: {dtype}") - - if isinstance(dtype, type(self)): - return self - elif isinstance(dtype, NumericMeta): - return dtype(self) - elif dtype is ir.Value: - if isinstance(self.value, (int, float, bool)): - res = arith_helper.const( - self.value, self.dtype.mlir_type, loc=loc, ip=ip - ) - elif isinstance(self.value, ir.Value): - res = self.value - else: - raise ValueError( - f"cannot convert {type(self)} to {dtype}, " - f"self.value is {self.value.type}" - ) - - if not isinstance(res, ArithValue): - raise ValueError(f"Expected ArithValue, got {type(res)} as {res.type}") - - return res.with_signedness(getattr(type(self), "signed", None)) - elif dtype in (int, float, bool): - if isinstance(self.value, ir.Value): - raise ValueError( - f"unable to convert {self.value} to static type: {dtype}" - ) - return dtype(self.value) - else: - raise ValueError(f"unable to convert {type(self)} to {dtype}") - - def ir_value(self, *, loc=None, ip=None) -> ir.Value: - return self.to(ir.Value, loc=loc, ip=ip) - - @property - def zero(self) -> "Numeric": ... - - def __dsl_not__(self, *, loc=None, ip=None): - """DSL implementation of Python's `not` operator. - - Returns True if the value is equal to zero, False otherwise. - This matches Python's behavior where any non-zero number is considered True. - - :param loc: The source location information, defaults to None - :type loc: Optional[Location] - :param ip: The insertion point for the operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: The result of the logical not operation - :rtype: Boolean - """ - if isinstance(self.value, (int, float, bool)): - return not self.value - else: - ty = type(self) - zero_val = arith.constant(ty.mlir_type, ty.zero) - return self.__eq__(ty(zero_val), loc=loc, ip=ip) - - def __dsl_and__(self, other, *, loc=None, ip=None): - """DSL implementation of Python's `and` operator. - - Returns the second operand if the first is truthy, otherwise returns the first operand. - A numeric value is considered truthy if it is non-zero. - - :param other: The right-hand operand - :type other: Numeric - :param loc: The source location information, defaults to None - :type loc: Optional[Location] - :param ip: The insertion point for the operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: The result of the logical and operation - :rtype: Boolean - - Example:: - - 5 and 3 -> 3 - 0 and 3 -> 0 - 3 and 0 and ... -> 0 - """ - is_true = self.__dsl_bool__(loc=loc, ip=ip) - - def and_op(lhs, rhs): - if isinstance(lhs, (int, float, bool)): - if isinstance(rhs, (int, float, bool)): - return lhs and rhs - else: - lhs = arith.constant(rhs.type, lhs) - return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) - else: - if isinstance(rhs, (int, float, bool)): - rhs = arith.constant(lhs.type, rhs) - return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) - else: - return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) - - return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip) - - def __dsl_or__(self, other, *, loc=None, ip=None): - """DSL implementation of Python's `or` operator. - - Returns the first operand if it is truthy, otherwise returns the second operand. - A numeric value is considered truthy if it is non-zero. - - :param other: The right-hand operand - :type other: Numeric - :param loc: The source location information, defaults to None - :type loc: Optional[Location] - :param ip: The insertion point for the operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: The result of the logical or operation - :rtype: Boolean - - Example:: - - 5 or 3 -> 5 - 0 or 3 -> 3 - 3 or 0 -> 3 - """ - is_true = self.__dsl_bool__(loc=loc, ip=ip) - - def or_op(lhs, rhs): - if isinstance(lhs, (int, float, bool)): - if isinstance(rhs, (int, float, bool)): - return lhs or rhs - else: - lhs = arith.constant(rhs.type, lhs) - return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) - else: - if isinstance(rhs, (int, float, bool)): - rhs = arith.constant(lhs.type, rhs) - return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) - else: - return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) - - return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip) - - def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean": - """DSL implementation of Python's __bool__ method. - - Returns a Boolean indicating whether this value is considered truthy. - For numeric types, returns True if the value is non-zero. - - :param loc: The source location information, defaults to None - :type loc: Optional[Location] - :param ip: The insertion point for the operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: True if this value is truthy (non-zero), False otherwise - :rtype: Boolean - """ - zero = type(self).zero - return self.__ne__(zero, loc=loc, ip=ip) - - def __bool__(self): - if isinstance(self.value, (int, float, bool)): - return bool(self.value) - else: - raise DSLRuntimeError( - f"Unable to convert dynamic `{type(self).__name__}` value to bool at compile time.", - suggestion=[ - "Decorate the parent function with `jit` decorator and with `preprocess` enabled.", - "Ensure not using patterns that DSL does not support.", - "Otherwise, please file a bug report.", - ], - ) - - def __index__(self): - if isinstance(self.value, (int, float, bool)): - return self.value - else: - raise DSLRuntimeError( - f"'{type(self.value)}' object cannot be interpreted as an integer", - suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator", - ) - - def __neg__(self, *, loc=None, ip=None): - if isinstance(self, (bool, int, float)): - return type(self)(-self.value) # type: ignore - else: - return type(self)(-self.value, loc=loc, ip=ip) # type: ignore - - @staticmethod - def _from_python_value(value): - if isinstance(value, Numeric): - return value - - if isinstance(value, bool): - res_type = Boolean - elif isinstance(value, int): - res_type = Int32 - elif isinstance(value, float): - res_type = Float32 - elif isinstance(value, ArithValue): - res_type = Numeric.from_mlir_type(value.type) - else: - raise ValueError( - f"unable to convert {value} in type {type(value)} to Numeric" - ) - return res_type(value) - - def __add__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip) - - def __sub__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip) - - def __mul__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip) - - def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.floordiv, promote_bool=True)( - self, other, loc=loc, ip=ip - ) - - def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.truediv, promote_bool=True)( - self, other, loc=loc, ip=ip - ) - - def __mod__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip) - - def __radd__(self, other, *, loc=None, ip=None) -> "Numeric": - return self.__add__(other, loc=loc, ip=ip) - - def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.sub, promote_bool=True, flip=True)( - self, other, loc=loc, ip=ip - ) - - def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric": - return self.__mul__(other, loc=loc, ip=ip) - - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.floordiv, promote_bool=True, flip=True)( - self, other, loc=loc, ip=ip - ) - - def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.truediv, promote_bool=True, flip=True)( - self, other, loc=loc, ip=ip - ) - - def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.mod, promote_bool=True, flip=True)( - self, other, loc=loc, ip=ip - ) - - def __eq__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.eq)(self, other, loc=loc, ip=ip) # type: ignore - - def __ne__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.ne)(self, other, loc=loc, ip=ip) # type: ignore - - def __lt__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.lt)(self, other, loc=loc, ip=ip) # type: ignore - - def __le__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.le)(self, other, loc=loc, ip=ip) # type: ignore - - def __gt__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.gt)(self, other, loc=loc, ip=ip) # type: ignore - - def __ge__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.ge)(self, other, loc=loc, ip=ip) # type: ignore - - def __pow__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.pow)(self, other, loc=loc, ip=ip) # type: ignore - - def __c_pointers__(self): - raise ValueError( - f"only support built-in types: bool, (u)int{8, 16, 32, 64}, float{32, 64}, but got {type(self)}" - ) - - def __get_mlir_types__(self): - return [type(self).mlir_type] - - @staticmethod - def from_mlir_type(mlir_type): - type_map = { - T.bool(): Boolean, - T.f64(): Float64, - T.f32(): Float32, - T.tf32(): TFloat32, - T.f16(): Float16, - T.bf16(): BFloat16, - T.i(128): Int128, - T.i64(): Int64, - T.i32(): Int32, - T.i16(): Int16, - T.i8(): Int8, - T.si(128): Int128, - T.si64(): Int64, - T.si32(): Int32, - T.si16(): Int16, - T.si8(): Int8, - T.ui(128): Uint128, - T.ui64(): Uint64, - T.ui32(): Uint32, - T.ui16(): Uint16, - T.ui8(): Uint8, - T.f8E5M2(): Float8E5M2, - T.f8E4M3(): Float8E4M3, - T.f8E4M3FN(): Float8E4M3FN, - T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ, - T.f4E2M1FN(): Float4E2M1FN, - T.f6E2M3FN(): Float6E2M3FN, - T.f6E3M2FN(): Float6E3M2FN, - T.f8E8M0FNU(): Float8E8M0FNU, - } - if mlir_type not in type_map: - raise DSLRuntimeError(f"Unsupported DSL type: {mlir_type}") - return type_map[mlir_type] - - -def as_numeric(obj: Union[bool, int, float, ir.Value, Numeric]) -> Numeric: - """Convert a Python primitive value to a Numeric type. - - :param obj: Python primitive value to convert - :type obj: Union[bool, int, float] - :return: The converted Numeric object - :rtype: Numeric - - Example:: - - .. code-block:: python - - x = as_numeric(5) # Converts to Int32 - y = as_numeric(3.14) # Converts to Float32 - z = as_numeric(True) # Converts to Boolean - """ - if isinstance(obj, Numeric): - return obj - return Numeric._from_python_value(obj) - - -class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True): - """A class representing integer values with specific width and signedness. - - This class provides functionality to create and manipulate integer values with - configurable width and signedness. It supports conversion from various input types - including Python scalars, MLIR Values, and other numeric types. - - :param x: The input value to convert to this integer type - :type x: Union[bool, int, float, ir.Value, Integer, Float] - - :return: A new Integer instance with the converted value - :rtype: Integer - - :raises AssertionError: If the type's numpy_dtype is None - :raises NotImplementedError: If converting between different Integer types - :raises ValueError: If the input type is not supported for conversion - :raises OverflowError: If converting float infinity to integer - - Type conversion behavior: - - * Python scalars (bool, int, float): - * Converted through numpy dtype casting - * NaN and infinity values are rejected - * Example: Int8(256) -> -256 (overflow behavior) - - * MLIR Value with IntegerType: - * Width differences handled by signless to signed/unsigned conversion - * Example: i8 -> i8/ui8 depending on target type - - * MLIR Value with FloatType: - * Uses MLIR float-to-int conversion - * NaN and infinity values is undefined behavior - * Example: f32 -> i32/ui32 depending on target type - - * Integer: - * Uses MLIR float-to-int conversion or numpy dtype casting - * Example: Int32(Int32(5)) => 5 - - * Float: - * Uses MLIR float-to-int conversion - * Example: Int32(Float(5.7)) -> 5 - - Example usage: - - .. code-block:: python - - x = Int32(5) # From integer - y = Int32(True) # From boolean - z = Int32(3.7) # From float (truncates) - w = Int32(x) # From same Integer type - c5 = arith.constant(5, T.i32()) - a = Int32(c5) # Treat c5 as int32 bitwise - """ - - def __init__(self, x, *, loc=None, ip=None): - ty = type(self) - - if isinstance(x, (bool, int, float)): - # Add check for NaN before numpy conversion - if isinstance(x, float): - if np.isnan(x): - raise ValueError("Cannot convert float NaN to integer") - elif np.isinf(x): - raise OverflowError("Cannot convert float infinity to integer") - - np_dtype = ty.numpy_dtype - assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" - x_val = int(np.array(x).astype(np_dtype)) - elif type(x) == ty: - x_val = x.value - elif isinstance(x, ir.Value): # type: ignore - x_val = x - if isinstance(x.type, ir.IntegerType): # type: ignore - if x.type.width != ty.width: - # signless -> (u)int - x_val = _arith_signless_to_int(x, ty) - elif isinstance(x.type, ir.FloatType): # type: ignore - # float -> (u)int - x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip) - elif isinstance(x, Integer): - if isinstance(x.value, ir.Value): - x_val = arith_helper.int_to_int(x.ir_value(), ty) - else: - # For non-MLIR values, use numpy casting - src_val = np.array(x.value, dtype=type(x).numpy_dtype) - x_val = int(src_val.astype(ty.numpy_dtype)) - elif isinstance(x, Float): - # float -> int is handled by Integer.__init__ recursively - Integer.__init__(self, x.value) - return - else: - raise DSLRuntimeError(f"{x} to integer conversion is not supported") - - super().__init__(x_val) - - def __invert__(self, *, loc=None, ip=None): - res_type = type(self) - return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip)) - - def __lshift__(self, other, *, loc=None, ip=None): - return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip) - - def __rlshift__(self, other, *, loc=None, ip=None): - other_ = as_numeric(other) - if not isinstance(other_, Integer): - raise ValueError(f"Cannot left shift {other_} with {self}") - return other_.__lshift__(self, loc=loc, ip=ip) - - def __rshift__(self, other, *, loc=None, ip=None): - return _binary_op(operator.rshift)(self, other, loc=loc, ip=ip) - - def __rrshift__(self, other, *, loc=None, ip=None): - other_ = as_numeric(other) - if not isinstance(other_, Integer): - raise ValueError(f"Cannot right shift {other_} with {self}") - return other_.__rshift__(self, loc=loc, ip=ip) - - def __and__(self, other, *, loc=None, ip=None): - return _binary_op(operator.and_)(self, other, loc=loc, ip=ip) - - def __rand__(self, other, *, loc=None, ip=None): - return self.__and__(other, loc=loc, ip=ip) - - def __or__(self, other, *, loc=None, ip=None): - return _binary_op(operator.or_)(self, other, loc=loc, ip=ip) - - def __ror__(self, other, *, loc=None, ip=None): - return self.__or__(other, loc=loc, ip=ip) - - def __xor__(self, other, *, loc=None, ip=None): - return _binary_op(operator.xor)(self, other, loc=loc, ip=ip) - - def __rxor__(self, other, *, loc=None, ip=None): - return self.__xor__(other, loc=loc, ip=ip) - - -class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True): - """A class representing floating-point values. - - :param x: The input value to convert to this float type. - :type x: Union[bool, int, float, ir.Value, Integer, Float] - - Type conversion behavior: - - 1. Python scalars (bool, int, float): - - Converted through numpy dtype casting - - Example: Float32(1.7) -> 1.7 - - 2. MLIR Value with FloatType: - - If width differs: converts between float types - - Example: f16 -> f32 - - 3. MLIR Value with IntegerType: - - Not supported, raises ValueError - - 4. Integer: - - Converts using MLIR int-to-float operation - - Example: Float32(Int32(5)) -> 5.0 - - 5. Float: - - Direct conversion between float types - - Example: Float32(Float32(1.5)) -> 1.5 - - .. note:: - The following narrow precision types are only supported in device code: - - 8-bit float types: - - Float8E5M2 - - Float8E4M3 - - Float8E4M3FN - - Float8E8M0FNU - - Float8E4M3B11FNUZ - - 6-bit float types: - - Float6E3M2FN - - Float6E2M3FN - - 4-bit float types: - - Float4E2M1FN - - Narrow precision types and special floating-point formats support matrix on device: - - :raises AssertionError: If the type's numpy_dtype is None - :raises ValueError: If conversion from the input type is not supported - """ - - def __init__(self, x, *, loc=None, ip=None): - ty = type(self) - - if isinstance(x, (bool, int, float)): # type: ignore - # Why we need to convert x to with numpy? - # np_dtype = ty.numpy_dtype - # assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" - # x = float(np.array(x).astype(np_dtype)) - super().__init__(float(x)) - elif isinstance(x, ir.Value): # type: ignore - if isinstance(x.type, ir.IntegerType): # type: ignore - raise DSLRuntimeError("signless to float conversion is not implemented") - elif isinstance(x.type, ir.FloatType): # type: ignore - if x.type != ty.mlir_type: - x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip) - super().__init__(x) - elif isinstance(x, Integer): - if isinstance(x.value, ir.Value): # type: ignore - x = arith_helper.itofp( - x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip - ) - else: - x = float(x.value) - super().__init__(x) - elif isinstance(x, Float): - Float.__init__(self, x.value) - else: - raise DSLRuntimeError(f"{x} to Float conversion is not supported") - - -class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.bool): - """Boolean type representation in the DSL. - - This class represents boolean values in the DSL, with a width of 1 bit. - It supports conversion from various types to boolean values. - - :param a: Value to convert to Boolean - :type a: Union[bool, int, float, "Value", Numeric] - :param loc: Source location information, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: Optional[InsertionPoint], optional - :raises DSLRuntimeError: If the input value cannot be converted to Boolean - - Conversion rules: - - 1. Python bool/int/float: - - Converted using Python's bool() function - - Example: Boolean(1) -> True, Boolean(0) -> False - - 2. Numeric: - - Uses the Numeric.value to construct Boolean recursively - - 3. MLIR Value with IntegerType: - - If width is 1: Direct assignment - - Otherwise: Compares with 0 using arith.cmpi - - 4. MLIR Value with FloatType: - - Compares with 0.0 using arith.cmpf - - Uses unordered comparison to handle NaN values - """ - - def __init__( - self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None - ): - value = None - if isinstance(a, (bool, int, float)): - value = bool(a) - elif isinstance(a, Numeric): - Boolean.__init__(self, a.value, loc=loc, ip=ip) - return - elif isinstance(a, ArithValue): - if a.type == T.bool(): - value = a - else: - value = a != arith_helper.const(0, a.type, loc=loc, ip=ip) - if value is None: - raise DSLRuntimeError(f"Cannot convert {a} to Boolean") - super().__init__(value, loc=loc, ip=ip) - self._value_int8 = None - - def ir_value_int8(self, *, loc=None, ip=None): - """ - Returns int8 ir value of Boolean. - When we need to store Boolean tensor element, use ir_value_int8(). - - :param loc: Source location information, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: Optional[InsertionPoint], optional - :return: The int8 value of this Boolean - :rtype: ir.Value - """ - if self._value_int8 is not None: - return self._value_int8 - self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value() - return self._value_int8 - - def __neg__(self, *, loc=None, ip=None): - """Negation operator is not supported for boolean type. - - :param loc: Source location information, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: Optional[InsertionPoint], optional - :raises TypeError: Always raises this error as negation is not supported - """ - raise TypeError("Negation, the operator `-` is not supported for boolean type") - - -class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8): ... - - -class Int16(Integer, metaclass=IntegerMeta, width=16, signed=True, mlir_type=T.i16): ... - - -class Int32(Integer, metaclass=IntegerMeta, width=32, signed=True, mlir_type=T.i32): ... - - -class Int64(Integer, metaclass=IntegerMeta, width=64, signed=True, mlir_type=T.i64): ... - - -class Int128( - Integer, metaclass=IntegerMeta, width=128, signed=True, mlir_type=lambda: T.i(128) -): ... - - -class Uint8(Integer, metaclass=IntegerMeta, width=8, signed=False, mlir_type=T.i8): ... - - -class Uint16( - Integer, metaclass=IntegerMeta, width=16, signed=False, mlir_type=T.i16 -): ... - - -class Uint32( - Integer, metaclass=IntegerMeta, width=32, signed=False, mlir_type=T.i32 -): ... - - -class Uint64( - Integer, metaclass=IntegerMeta, width=64, signed=False, mlir_type=T.i64 -): ... - - -class Uint128( - Integer, metaclass=IntegerMeta, width=128, signed=False, mlir_type=lambda: T.i(128) -): ... - - -class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64): - def __c_pointers__(self): - if not isinstance(self.value, float): - raise ValueError("only float is supported") - - return [ - ctypes.cast(ctypes.pointer(ctypes.c_double(self.value)), ctypes.c_void_p) - ] - - -class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32): - @staticmethod - def _get_c_pointer(value: float): - return ctypes.cast(ctypes.pointer(ctypes.c_float(value)), ctypes.c_void_p) - - def __c_pointers__(self): - if not isinstance(self.value, float): - raise ValueError("only float is supported") - - return [Float32._get_c_pointer(self.value)] - - -class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32): - def __c_pointers__(self): - if not isinstance(self.value, float): - raise ValueError("only float is supported") - return [Float32._get_c_pointer(self.value)] - - -class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16): - @staticmethod - def _get_c_pointer(value: float): - # Convert float to float16 binary representation - # First convert to numpy float16 to handle the conversion - f16_val = np.float16(value) - # Get the raw bits as a 16-bit integer - bits = f16_val.view(np.uint16) - # Create a short (16-bit int) with those bits - c_val = ctypes.c_short(bits) - return ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p) - - def __c_pointers__(self): - if not isinstance(self.value, float): - raise ValueError("only float is supported") - return [Float16._get_c_pointer(self.value)] - - -class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16): - def __c_pointers__(self): - if not isinstance(self.value, float): - raise ValueError("only float is supported") - - return Float.__c_pointers__(self) - - -class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2): ... - - -class Float8E4M3FN(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3FN): ... - - -class Float8E4M3B11FNUZ( - Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3B11FNUZ -): ... - - - -# Added missing float types -class Float8E4M3(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3): ... - - -class Float8E8M0FNU(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E8M0FNU): ... - - -class Float4E2M1FN(Float, metaclass=FloatMeta, width=4, mlir_type=T.f4E2M1FN): ... - - -class Float6E3M2FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E3M2FN): ... - - -class Float6E2M3FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E2M3FN): ... - - -_unsupported_dst_float_types = [ - Float8E4M3, - Float8E4M3B11FNUZ, - Float4E2M1FN, - Float6E3M2FN, - Float6E2M3FN, -] - - -ALL_DTYPES = { - Int8, - Int16, - Int32, - Int64, - Int128, - Uint8, - Uint16, - Uint32, - Uint64, - Uint128, - BFloat16, - Float16, - Float32, - TFloat32, - Float64, - Float8E5M2, - Float8E4M3, - Float8E4M3FN, - Float8E8M0FNU, - Float8E4M3B11FNUZ, - Float4E2M1FN, - Float6E2M3FN, - Float6E3M2FN, -} -__STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES} - - -def dtype(dtype_) -> Type[Numeric]: - t = None - if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__): - t = __STR_TO_DTYPE__[dtype_] - else: - raise TypeError(f"can't interpret {dtype_} as data type") - - return t - - -############################################################## -# Tensor -############################################################## - - -class TensorMeta(DslType): - _element_type = Any - _shape = Any - - """ - Examples: - >>> Tensor[Int32, (3,)] - >>> Tensor[Float32, (3, 4)] - >>> T = TypeVar("T") - >>> Tensor[T, (3, 4, 5)] - """ - - def __new__(cls, name, bases, attrs, element_type=Any, shape=Any): - new_cls = super().__new__(cls, name, bases, attrs) - new_cls._element_type = element_type - new_cls._shape = shape - return new_cls - - -# Generic type -TY = TypeVar("TY") - - -class Constexpr(Generic[TY]): - """Value is passed and computed by python interpreter""" - - pass - - -class align: - def __init__(self, value: int): - if value <= 0 or (value & (value - 1)) != 0: - raise DSLRuntimeError("expects align be power of 2 as positive value") - self._value = value - - def __str__(self): - return f"align({self._value})" - - -class PointerMeta(DslType): - def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)): - new_cls = super().__new__( - cls, - name, - bases, - attrs, - mlir_type=lambda: getattr(ir, "UnrankedMemRefType").get( - value_type.mlir_type, getattr(ir, "Attribute").parse("0") - ), - ) - new_cls._value_type = value_type - new_cls._align = align_ - return new_cls - - def __eq__(cls, other): - if not isinstance(other, PointerMeta): - return False - return ( - cls._value_type == other._value_type - and cls._align._value == other._align._value - ) # Compare alignment values - - def __hash__(cls): - return hash((cls._value_type, cls._align._value)) # Hash alignment value - - def __getitem__(cls, params) -> Type["Pointer"]: - value_type, align_ = params - - if not isinstance(align_, align): - raise DSLRuntimeError(f"expects align but got {align_}") - - # Create new class with proper name and parameters - new_cls = type( - f"Pointer[{value_type.__name__}, {align_}]", - (Pointer,), - {}, - value_type=value_type, - align_=align_, # Pass alignment to __new__ - ) - return new_cls - - def __str__(cls): - return f"ptr<{cls._value_type}, {cls._align}>" - - -class Pointer(metaclass=PointerMeta): - """ - A pointer to a memory location. - - Examples: - - def foo(a : Pointer[Int32, align=8]): - ... - - """ - - def __init__(self, value): - self.value = value - - def __str__(self): - return f"{self.value} : {type(self)}" - - -class IRConst(Generic[TY]): - """Value is passed as MLIR constant value for (arith.constant).""" - - def __init__(self, ty: TY): - self.ty = ty - - -class IRValue(Generic[TY]): - """Value is passed as MLIR dynamic value.""" - - def __init__(self, ty: TY): - self.ty = ty - - -class IRVariadic: - """ - A helper class to pass a variadic number of arguments to a function. - """ - - def __init__(self, operands): - """ - Create a list of variadic operands. `operands` must be dynamic values. - """ - self.operands = operands - - def block_arg_types(self): - """ - Return the list of block args types. - """ - return [operand.type for operand in self.operands] - - def set_func_args(self, block_args): - """ - This function is called after entering a function. `block_args` are the - block arguments that correspond to the passed operands. Derived classes - may implement this function to provide convenience getters for block - arguments. - """ - pass - - def __len__(self): - """ - Return the length of variadic operands. - """ - return len(self.operands) - - -class FuncArgWithAttr(IRValue): - """ - This derived class is specifically for func op arg with attr - """ - - def __init__(self, ty, attr_name, attr_ty, attr_value=None): - super().__init__(ty) - assert attr_name is not None and ( - attr_ty is not None or attr_value is not None - ), "Invalid attr_name and/or attr_ty and/or attr_value for FuncArgWithAttr" - self.attr_name = attr_name - self.attr_ty = attr_ty - self.attr_value = attr_value - - - -def implicitDowncastNumericType(value): - if isinstance(value, Numeric): - return value.ir_value() - return value - - -__all__ = [ - "DslType", - "Numeric", - "NumericMeta", - "IntegerMeta", - "FloatMeta", - "Boolean", - "Integer", - "Int16", - "Int32", - "Int64", - "Int128", - "Int8", - "Uint8", - "Uint16", - "Uint32", - "Uint64", - "Uint128", - "Float", - "Float16", - "BFloat16", - "TFloat32", - "Float32", - "Float64", - "Float8E5M2", - "Float8E4M3", - "Float8E4M3FN", - "Float8E4M3B11FNUZ", - "Float8E4M3", - "Float8E8M0FNU", - "Float4E2M1FN", - "Float6E2M3FN", - "Float6E3M2FN", - "as_numeric", - "align", - "Pointer", - "dtype", - "Constexpr", - "IRConst", - "IRValue", - "IRVariadic", - "implicitDowncastNumericType", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py deleted file mode 100644 index c4bfb2b7d91ee72b04a89de59e7dfbdec2be646c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from . import stacktrace -from . import logger -from . import timer -__all__ = [ - "logger", - "timer", - "stacktrace", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py deleted file mode 100644 index d4e4b4edf359ec86b6b5806cb0b2296f9cb918f6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides logging helper functions -""" - -import logging - -logger = None - - -def log(): - return logger - - -def setup_log( - name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1 -): - """Set up and configure a logger with console and/or file handlers. - - :param name: Name of the logger to create - :type name: str - :param log_to_console: Whether to enable logging to console, defaults to False - :type log_to_console: bool, optional - :param log_to_file: Whether to enable logging to file, defaults to False - :type log_to_file: bool, optional - :param log_file_path: Path to the log file, required if log_to_file is True - :type log_file_path: str, optional - :param log_level: Logging level to set, defaults to 1 - :type log_level: int, optional - :raises ValueError: If log_to_file is True but log_file_path is not provided - :return: Configured logger instance - :rtype: logging.Logger - """ - # Create a custom logger - global logger - logger = logging.getLogger(name) - if log_to_console or log_to_file: - logger.setLevel(log_level) - else: - # Makes sure logging is OFF - logger.setLevel(logging.CRITICAL + 1) - - # Clear existing handlers to prevent duplicate logs - if logger.hasHandlers(): - logger.handlers.clear() - - # Define formatter - formatter = logging.Formatter( - f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s" - ) - - # Add console handler if enabled - if log_to_console: - console_handler = logging.StreamHandler() - console_handler.setLevel(log_level) - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - # Add file handler if enabled - if log_to_file: - if not log_file_path: - raise ValueError("log_file_path must be provided when enable_file is True") - file_handler = logging.FileHandler(log_file_path) - file_handler.setLevel(log_level) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - return logger - - -logger = setup_log("generic") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py deleted file mode 100644 index d2091098c173e8a941ed7958802dfbdee24199bc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py +++ /dev/null @@ -1,165 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" - This module provides stacktrace helper functions -""" - -import os -import re - - -def walk_to_top_module(start_path): - """ - Walk up from the start_path to find the top-level Python module. - - :param start_path: The path to start from. - :return: The path of the top-level module. - """ - current_path = start_path - - while True: - # Check if we are at the root directory - if os.path.dirname(current_path) == current_path: - break - - # Check for __init__.py - init_file_path = os.path.join(current_path, "__init__.py") - if os.path.isfile(init_file_path): - # If __init__.py exists, move up one level - current_path = os.path.dirname(current_path) - else: - # If no __init__.py, we are not in a module; stop - break - - # If we reached the root without finding a module, return None - if os.path.dirname(current_path) == current_path and not os.path.isfile( - os.path.join(current_path, "__init__.py") - ): - return None - - # Return the path of the top-level module - return current_path - - -def _filter_internal_frames(traceback, internal_path): - """ - Filter out stack frames from the traceback that belong to the specified module path. - - This function removes stack frames from the traceback whose file paths start with - the given prefix_path, effectively hiding internal implementation details from - the error traceback shown to users. - """ - iter_prev = None - iter_tb = traceback - while iter_tb is not None: - if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith( - internal_path - ): - if iter_tb.tb_next: - if iter_prev: - iter_prev.tb_next = iter_tb.tb_next - else: - traceback = iter_tb.tb_next - else: - iter_prev = iter_tb - iter_tb = iter_tb.tb_next - return traceback - - -_generated_function_names = re.compile( - r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$" -) - - -def _filter_duplicated_frames(traceback): - """ - Filter out duplicated stack frames from the traceback. - The function filters out consecutive frames that are in the same file and have the same line number. - In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame. - """ - iter_prev = None - iter_tb = traceback - while iter_tb is not None: - skip_current = False - skip_next = False - if iter_tb.tb_next: - current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename) - next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename) - # if in the same file, check if the line number is the same - if current_filename == next_filename: - current_lineno = iter_tb.tb_lineno - next_lineno = iter_tb.tb_next.tb_lineno - if current_lineno == next_lineno: - # Same file and line number, check name, if current is generated, skip current, otherwise skip next - name = iter_tb.tb_frame.f_code.co_name - is_generated = bool(_generated_function_names.match(name)) - if is_generated: - # Skip current - skip_current = True - else: - # Skip next if it's generated, otherwise keep both - next_name = iter_tb.tb_next.tb_frame.f_code.co_name - skip_next = bool(_generated_function_names.match(next_name)) - if skip_current: - if iter_prev: - iter_prev.tb_next = iter_tb.tb_next - else: - traceback = iter_tb.tb_next - elif skip_next: - # if next is last frame, don't skip - if iter_tb.tb_next.tb_next: - iter_tb.tb_next = iter_tb.tb_next.tb_next - iter_prev = iter_tb - else: - iter_prev = iter_tb - iter_tb = iter_tb.tb_next - - return traceback - - -def filter_stackframe(traceback, prefix_path): - """ - Filter out stack frames from the traceback that belong to the specified module path. - - This function removes stack frames from the traceback whose file paths start with - the given prefix_path, effectively hiding internal implementation details from - the error traceback shown to users. - - :param traceback: The traceback object to filter. - :param prefix_path: The path prefix to filter out from the traceback. - :return: The filtered traceback with internal frames removed. - """ - # Step 1: filter internal frames - traceback = _filter_internal_frames(traceback, prefix_path) - - # Step 2: consolidate duplicated frames - return _filter_duplicated_frames(traceback) - - -def filter_exception(value, module_dir): - """ - Filter out internal implementation details from exception traceback. - - This function recursively processes an exception and its cause chain, - removing stack frames that belong to the specified module directory. - This helps to present cleaner error messages to users by hiding - implementation details. - - :param value: The exception object to filter. - :param module_dir: The module directory path to filter out from tracebacks. - :return: The filtered exception with internal frames removed. - """ - if hasattr(value, "__cause__") and value.__cause__: - filter_exception(value.__cause__, module_dir) - - if hasattr(value, "__traceback__"): - filter_stackframe(value.__traceback__, module_dir) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py deleted file mode 100644 index f41d3f7410c0227ff1b1f8df4b8ce14557cf649b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides a timing helper functions -""" -from functools import wraps - -from .logger import log - - -# TODO: revisit this part when mlir timing manager is ready for pybind. -def timer(*dargs, **kwargs): - enable = kwargs.get("enable", True) - - def decorator(func): - @wraps(func) - def func_wrapper(*args, **kwargs): - if not enable: - return func(*args, **kwargs) - from time import time - - start = time() - result = func(*args, **kwargs) - end = time() - - # Convert time from seconds to us - spend_us = (end - start) * 1e6 - - # Determine the function type and format the log message - if hasattr(func, "__name__"): - func_name = func.__name__ - log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs" - elif "CFunctionType" in str(type(func)): - log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs" - else: - log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs" - - log().info(log_message) - - return result - - return func_wrapper - - if len(dargs) == 1 and callable(dargs[0]): - return decorator(dargs[0]) - else: - return decorator diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py deleted file mode 100644 index f2c7ed2607675990ad9579fa06b25935b2ccb46e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .cutlass_dsl import ( - Constexpr, - as_numeric, - min, - max, - and_, - or_, - all_, - any_, - not_, - all_, - any_, - select_, - # Control-flow without AST pre-processor - if_generate, - for_generate, - LoopUnroll, - while_generate, - yield_out, - # Control-flow with AST pre-processor - range_constexpr, - range_dynamic, - const_expr, - dynamic_expr, - # Data types - dtype, # Provides conversions to types inheriting from NumericType - DSLRuntimeError, - JitArgAdapterRegistry, - # Construction utilities for user-defined classes - extract_mlir_values, - new_from_mlir_values, -) - -from .cute.typing import * - -# Utilities not belonging to CuTe -from . import utils as utils - -# Used as internal symbol -from . import cutlass_dsl as _dsl - -# Aliases -LaunchConfig = _dsl.BaseDSL.LaunchConfig -register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter -gpu = _dsl.cutlass_gpu -cuda = _dsl.cuda_helpers - -CACHE_FILE = "compiled_cache.db" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py deleted file mode 100644 index 8702ed9163837925057b48f9aafd11cffbb26a7e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py +++ /dev/null @@ -1,319 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -# Use the auto-generated enum AddressSpace -from cutlass._mlir.dialects.cute import AddressSpace - -# Explicitly import types that might be directly used by other modules. -# This is a fix for using Sphinx to generate documentation -# Because Sphinx processes each module in isolation, it won't be able to rely -# on re-exported symbols via wildcard imports (from .typing import *) in the -# same way that Python does at runtime. -from .typing import ( - Shape, - Stride, - IntTuple, - Coord, - Tile, - XTuple, - Tiler, - Layout, - Pointer, - Tensor, -) - -# Import everything else -from .typing import * - -from .core import ( - assume, - is_integer, - is_int_tuple, - is_static, - size, - has_underscore, - slice_, - make_ptr, - make_layout, - recast_layout, - make_fragment_like, - depth, - rank, - flatten_to_tuple, - flatten, - unflatten, - product, - product_like, - shape, - size_in_bytes, - make_identity_layout, - make_ordered_layout, - make_composed_layout, - make_layout_tv, - make_swizzle, - recast_ptr, - make_tensor, - make_identity_tensor, - make_fragment, - recast_tensor, - get, - select, - front, - is_major, - leading_dim, - find, - find_if, - coalesce, - group_modes, - cosize, - dice, - product_each, - prepend, - append, - prepend_ones, - append_ones, - ceil_div, - slice_and_offset, - crd2idx, - domain_offset, - elem_less, - transform_leaf, - filter_zeros, - filter, - tile_to_shape, - shape_div, - composition, - complement, - right_inverse, - left_inverse, - max_common_layout, - max_common_vector, - logical_product, - zipped_product, - tiled_product, - flat_product, - raked_product, - blocked_product, - flat_divide, - logical_divide, - zipped_divide, - tiled_divide, - local_partition, - local_tile, - printf, - print_tensor, - # tiled mma/tiled copy - make_mma_atom, - make_tiled_mma, - make_copy_atom, - make_tiled_copy_tv, - make_tiled_copy, - make_tiled_copy_S, - make_tiled_copy_D, - make_tiled_copy_A, - make_tiled_copy_B, - make_tiled_copy_C, - make_tiled_copy_C_atom, - basic_copy, - basic_copy_if, - autovec_copy, - copy, - copy_atom_call, - gemm, - # Wrapper classes - ComposedLayout, - Swizzle, - E, - Atom, - MmaAtom, - CopyAtom, - TiledCopy, - TiledMma, - TensorSSA, - ReductionOp, - full, - full_like, - empty_like, - ones_like, - zeros_like, - where, - any_, - all_, - # User defined struct - struct, - pretty_str, - make_layout_image_mask, - repeat_like, - round_up, - is_congruent, - is_weakly_congruent, - ScaledBasis, - get_divisibility, - Ratio, -) - -from . import arch -from . import nvgpu -from . import testing -from . import runtime - -# Export all math ops without "math." -from .math import * - -# Used as internal symbol -from .. import cutlass_dsl as _dsl - -# Aliases -jit = _dsl.CuTeDSL.jit -kernel = _dsl.CuTeDSL.kernel -register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter -compile = _dsl.compile - -# Explicitly export all symbols for documentation generation -__all__ = [ - # Core types - "AddressSpace", - "Tensor", - "Layout", - "ComposedLayout", - "Swizzle", - "E", - "Atom", - "MmaAtom", - "CopyAtom", - "TiledCopy", - "TiledMma", - "TensorSSA", - # Basic utility functions - "assume", - "is_integer", - "is_int_tuple", - "is_static", - "size", - "has_underscore", - "slice_", - "depth", - "rank", - "shape", - "printf", - "print_tensor", - "pretty_str", - # Layout functions - "make_layout", - "recast_layout", - "make_identity_layout", - "make_ordered_layout", - "make_composed_layout", - "make_layout_tv", - "make_layout_image_mask", - # Tensor functions - "make_ptr", - "make_tensor", - "make_identity_tensor", - "make_fragment", - "make_fragment_like", - "recast_ptr", - "recast_tensor", - # Tensor manipulation - "get", - "select", - "front", - "is_major", - "leading_dim", - "find", - "find_if", - "coalesce", - "group_modes", - "cosize", - "size_in_bytes", - # Tuple operations - "flatten_to_tuple", - "flatten", - "product", - "product_like", - "product_each", - "prepend", - "append", - "prepend_ones", - "append_ones", - # Math operations - "ceil_div", - "round_up", - # Layout operations - "slice_and_offset", - "crd2idx", - "domain_offset", - "elem_less", - "filter_zeros", - "filter", - "tile_to_shape", - "shape_div", - "dice", - # Layout algebra - "composition", - "complement", - "right_inverse", - "left_inverse", - "max_common_layout", - "max_common_vector", - "is_congruent", - "is_weakly_congruent", - # Product operations - "logical_product", - "zipped_product", - "tiled_product", - "flat_product", - "raked_product", - "blocked_product", - # Division operations - "flat_divide", - "logical_divide", - "zipped_divide", - "tiled_divide", - "local_partition", - "local_tile", - # MMA and Copy operations - "make_mma_atom", - "make_tiled_mma", - "make_copy_atom", - "make_tiled_copy_tv", - "make_tiled_copy", - "make_tiled_copy_C_atom", - "basic_copy", - "basic_copy_if", - "autovec_copy", - "copy", - "copy_atom_call", - "gemm", - # Tensor creation - "full", - "full_like", - "empty_like", - "ones_like", - "zeros_like", - "where", - "any_", - "all_", - "repeat_like", - "ScaledBasis", - # User defined struct - "struct", - # Modules - "arch", - "nvgpu", - "testing", - "runtime", - # Decorators and code generation - "jit", - "kernel", - "register_jit_arg_adapter", - "compile", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py deleted file mode 100644 index 01198215f74b07f224b1d5e53ff37075775bb201..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .elect import * -from .mbar import * -from .nvvm_wrappers import * -from .smem import * -from .tmem import * - -# __all__ is required here for documentation generation -__all__ = [ - # - # elect.py - # - "make_warp_uniform", - "elect_one", - # - # mbar.py - # - "mbarrier_init", - "mbarrier_init_fence", - "mbarrier_arrive_and_expect_tx", - "mbarrier_expect_tx", - "mbarrier_wait", - "mbarrier_try_wait", - "mbarrier_conditional_try_wait", - "mbarrier_arrive", - # - # nvvm_wrappers.py - # - "lane_idx", - "warp_idx", - "thread_idx", - "block_dim", - "block_idx", - "grid_dim", - "cluster_idx", - "cluster_dim", - "block_in_cluster_idx", - "block_in_cluster_dim", - "block_idx_in_cluster", - "shuffle_sync", - "shuffle_sync_up", - "shuffle_sync_down", - "shuffle_sync_bfly", - "barrier", - "barrier_arrive", - "sync_threads", - "sync_warp", - "fence_acq_rel_cta", - "fence_acq_rel_cluster", - "fence_acq_rel_gpu", - "fence_acq_rel_sys", - "cp_async_commit_group", - "cp_async_wait_group", - "cp_async_bulk_commit_group", - "cp_async_bulk_wait_group", - "cluster_wait", - "cluster_arrive", - "cluster_arrive_relaxed", - "fence_proxy", - "vote_ballot_sync", - "popc", - "fence_view_async_tmem_load", - "fence_view_async_tmem_store", - "warpgroup_reg_alloc", - "warpgroup_reg_dealloc", - "fma_packed_f32x2", - "mul_packed_f32x2", - "add_packed_f32x2", - "fmax", - "rcp_approx", - "exp2", - # Constants - "WARP_SIZE", - # Forward from auto-generated nvvm python - "ProxyKind", - "SharedSpace", - "RoundingModeKind", - # - # smem.py - # - "alloc_smem", - "get_dyn_smem", - "get_dyn_smem_size", - # - # tmem.py - # - "retrieve_tmem_ptr", - "alloc_tmem", - "relinquish_tmem_alloc_permit", - "dealloc_tmem", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py deleted file mode 100644 index ead552afab7de50a62f95eee7b4d8a2d9b4dfca9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op - -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects import nvvm, scf -from cutlass._mlir import ir - -from ..typing import Int, Int32 -from ...impl_utils import check_value_in - - -@dsl_user_op -def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32: - """ - Creates a warp-uniform value from the given integer input. - - :param value: The integer to make warp uniform. - :type value: Int - :return: The warp-uniform value equal to the input. - :rtype: Int32 - """ - return Int32( - _cute_nvgpu_ir.arch_make_warp_uniform( - Int32(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - ) - - -class IfOpRegion: - """ - A context manager for if Op. - Automatically inserts `scf.yield([])` when exiting the context. - """ - - def __init__(self, block, *, loc=None, ip=None): - self.block = block - self.insert_point = ir.InsertionPoint(self.block) - self.loc = loc - self.ip = ip - - def __enter__(self): - self.insert_point.__enter__() - return self.block.arguments - - def __exit__(self, exc_type, exc_value, traceback): - scf.yield_([], loc=self.loc, ip=self.ip) - self.insert_point.__exit__(exc_type, exc_value, traceback) - - -@dsl_user_op -def elect_one(*, loc=None, ip=None) -> IfOpRegion: - """ - Elects one thread within a warp. - - .. code-block:: python - - with elect_one(): - # Only one thread in the warp executes the code in this context - pass - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - is_thread_leader = nvvm.elect_sync(T.bool()) - if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip) - return IfOpRegion(if_op.then_block, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py deleted file mode 100644 index 80cb7b0b5fc6e226a39d68197382cbde2e32861d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py +++ /dev/null @@ -1,349 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. -from typing import Optional - -from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op - -from cutlass._mlir.dialects import nvvm -from cutlass._mlir import ir - -from ..typing import Pointer, Int, Boolean, Int32 -from ...impl_utils import check_value_in - - -#################################################################################################### -# -# Mbarrier management utilities -# -#################################################################################################### - - -@dsl_user_op -def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: - """ - Initializes a mbarrier with the specified thread arrival count. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param cnt: The arrival count of the mbarrier - :type cnt: Int - """ - nvvm.mbarrier_init_shared( - mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - - -@dsl_user_op -def mbarrier_init_fence(*, loc=None, ip=None) -> None: - """ - A fence operation that applies to the mbarrier initializations. - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - nvvm.fence_mbarrier_init(loc=loc, ip=ip) - - -@dsl_user_op -def mbarrier_arrive_and_expect_tx( - mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None -) -> None: - """ - Arrives on a mbarrier and expects a specified number of transaction bytes. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param bytes: The number of transaction bytes - :type bytes: Int - :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to - the mbarrier is converted to a remote address in the peer CTA's - SMEM. - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - - mbar_llvm_ptr = mbar_ptr.llvm_ptr - if peer_cta_rank_in_cluster is not None: - mbar_llvm_ptr = nvvm.mapa_shared_cluster( - mbar_llvm_ptr.type, - mbar_llvm_ptr, - Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - space = nvvm.MBarrierSpaceKind.CLUSTER - else: - space = nvvm.MBarrierSpaceKind.CTA - - nvvm.mbarrier_txn( - mbar_llvm_ptr, - Int32(bytes).ir_value(loc=loc, ip=ip), - kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX, - space=space, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def mbarrier_expect_tx( - mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None -) -> None: - """ - Expects a specified number of transaction bytes without an arrive. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param bytes: The number of transaction bytes - :type bytes: Int - :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to - the mbarrier is converted to a remote address in the peer CTA's - SMEM. - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - - mbar_llvm_ptr = mbar_ptr.llvm_ptr - if peer_cta_rank_in_cluster is not None: - mbar_llvm_ptr = nvvm.mapa( - mbar_llvm_ptr.type, - mbar_llvm_ptr, - Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - space = nvvm.MBarrierSpaceKind.CLUSTER - else: - space = nvvm.MBarrierSpaceKind.CTA - - nvvm.mbarrier_txn( - mbar_llvm_ptr, - Int32(bytes).ir_value(loc=loc, ip=ip), - kind=nvvm.MBarrierTxnKind.EXPECT_TX, - space=space, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: - """ - Waits on a mbarrier with a specified phase. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param phase: The phase to wait for (either 0 or 1) - :type phase: Int - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - - timeout_ns = 10000000 - # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX - # The timeout in ns only applies to the latter and this call is truly blocking - nvvm.mbarrier_try_wait_parity_shared( - mbar_ptr.llvm_ptr, - Int32(phase).ir_value(loc=loc, ip=ip), - Int32(timeout_ns).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean: - """ - Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param phase: The phase to wait for (either 0 or 1) - :type phase: Int - :return: A boolean value indicating whether the wait operation was successful - :rtype: Boolean - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - - return Boolean( - nvvm.mbarrier_wait_parity( - T.bool(), - mbar_ptr.llvm_ptr, - Int32(phase).ir_value(loc=loc, ip=ip), - nvvm.MBarrierWaitKind.TRY, - loc=loc, - ip=ip, - ) - ) - - -@dsl_user_op -def mbarrier_conditional_try_wait( - cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None -) -> Boolean: - """ - Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. - - :param cond: A boolean predicate - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param phase: The phase to wait for (either 0 or 1) - :type phase: Int - :return: A boolean value indicating whether the wait operation was successful - :rtype: Boolean - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - return if_generate( - cond, - lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip), - lambda: Boolean(True).ir_value(loc=loc, ip=ip), - None, - [Boolean], - ) - - -@dsl_user_op -def mbarrier_arrive( - mbar_ptr: Pointer, - peer_cta_rank_in_cluster: Optional[Int] = None, - *, - loc=None, - ip=None, -) -> None: - """ - Arrives on an mbarrier. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to - the mbarrier is converted to a remote address in the peer CTA's - SMEM. - """ - mbar_llvm_ptr = mbar_ptr.llvm_ptr - if peer_cta_rank_in_cluster is not None: - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - - mbar_llvm_ptr = nvvm.mapa_shared_cluster( - mbar_llvm_ptr.type, - mbar_llvm_ptr, - Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - space = nvvm.MBarrierSpaceKind.CLUSTER - else: - space = nvvm.MBarrierSpaceKind.CTA - - nvvm.mbarrier_txn( - mbar_llvm_ptr, - Int32(1).ir_value(loc=loc, ip=ip), - kind=nvvm.MBarrierTxnKind.ARRIVE, - space=space, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: - """ - Arrives on an mbarrier for async load **without incrementing** the arrival count - (`cp.async.mbarrier.arrive.shared ..., noinc=1`). - Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same - as the math/epilogue warp(consumer). - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) - - mbar_llvm_ptr = mbar_ptr.llvm_ptr - nvvm.cp_async_mbarrier_arrive_shared( - mbar_llvm_ptr, - noinc=True, - loc=loc, - ip=ip, - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py deleted file mode 100644 index 69e3b8acb1fd0d1bc6615cd835235c0bbd62027b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ /dev/null @@ -1,681 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from functools import partial -from typing import Optional, Tuple, Union, Callable -from typing_extensions import deprecated - -from cutlass.cutlass_dsl import T, dsl_user_op - -from cutlass._mlir import ir -from cutlass._mlir.dialects import llvm, nvvm, vector - -# Forward nvvm enums -from cutlass._mlir.dialects.nvvm import ( - ProxyKind, - SharedSpace, - Tcgen05WaitKind, - SetMaxRegisterAction, - RoundingModeKind, -) - -from ..typing import ( - Int, - Boolean, - Int16, - Uint16, - Int32, - Uint32, - Int64, - Float32, - BFloat16, - Numeric, - as_numeric, -) - -WARP_SIZE = 32 -FULL_MASK = 0xFFFFFFFF - - -@dsl_user_op -def lane_idx(*, loc=None, ip=None) -> Int32: - """ - Returns the lane index of the current thread within the warp. - """ - return Int32(nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip)) - - -@dsl_user_op -def warp_idx(*, loc=None, ip=None) -> Int32: - """ - Returns the warp index within a CTA. - """ - warp_size = 32 - tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)) - tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)) - tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)) - ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)) - ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)) - tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y - return tid // warp_size - - -@dsl_user_op -def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the thread index within a CTA. - """ - return ( - Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the number of threads in each dimension of the CTA. - """ - return ( - Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_ntid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the CTA identifier within a grid. - """ - return ( - Int32(nvvm.read_ptx_sreg_ctaid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_ctaid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_ctaid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the number of CTAs in each dimension of the grid. - """ - return ( - Int32(nvvm.read_ptx_sreg_nctaid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_nctaid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_nctaid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the cluster identifier within a grid. - """ - return ( - Int32(nvvm.read_ptx_sreg_clusterid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_clusterid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_clusterid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the number of clusters in each dimension of the grid. - """ - return ( - Int32(nvvm.read_ptx_sreg_nclusterid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_nclusterid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_nclusterid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the CTA index within a cluster across all dimensions. - """ - return ( - Int32(nvvm.read_ptx_sreg_cluster_ctaid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_cluster_ctaid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_cluster_ctaid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: - """ - Returns the dimensions of the cluster. - """ - return ( - Int32(nvvm.read_ptx_sreg_cluster_nctaid_x(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_cluster_nctaid_y(T.i32(), loc=loc, ip=ip)), - Int32(nvvm.read_ptx_sreg_cluster_nctaid_z(T.i32(), loc=loc, ip=ip)), - ) - - -@dsl_user_op -def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: - """ - Returns the linearized identifier of the CTA within the cluster. - """ - return Int32(nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip)) - - -@dsl_user_op -def shuffle_sync_op( - value: Numeric, - offset: Int, - mask: Int = FULL_MASK, - mask_and_clamp: Int = WARP_SIZE - 1, - kind: nvvm.ShflKind = nvvm.ShflKind.idx, - *, - loc=None, - ip=None, -) -> Numeric: - """ - Shuffles a value within the threads of a warp. - - :param value: The value to shuffle - :type value: Numeric - :param mask: A mask describing the threads participating in this operation - :type mask: Int - :param offset: A source lane or a source lane offset depending on kind - :type offset: Int - :param mask_and_clamp: An integer containing two packed values specifying a mask for logically - splitting warps into sub-segments and an upper bound for clamping the - source lane index. - :type mask_and_clamp: Int - :param kind: The kind of shuffle, can be idx, up, down, or bfly - :type kind: ShflKind - :return: The shuffled value - :rtype: Numeric - """ - if not isinstance(value, Numeric): - value = as_numeric(value) - if value.width > 64: - raise ValueError("shuffle_sync only supports values up to 64 bits") - - orig_type = type(value) - if value.width < 32: - if value.dtype.is_float: - value = value.to(Float32) - else: - if value.signed: - value = value.to(Int32) - else: - value = value.to(Uint32) - return orig_type( - nvvm.shfl_sync( - type(value).mlir_type, - Int32(mask).ir_value(loc=loc, ip=ip), - value.ir_value(loc=loc, ip=ip), - Int32(offset).ir_value(loc=loc, ip=ip), - Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), - kind, - loc=loc, - ip=ip, - ) - ) - elif value.width == 32: - return orig_type( - nvvm.shfl_sync( - type(value).mlir_type, - Int32(mask).ir_value(loc=loc, ip=ip), - value.ir_value(loc=loc, ip=ip), - Int32(offset).ir_value(loc=loc, ip=ip), - Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), - kind, - loc=loc, - ip=ip, - ) - ) - else: - if value.width != 64: - raise ValueError( - "shuffle_sync only supports 64 bits values when the bit width is larger than 32" - ) - value = llvm.bitcast( - T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip - ) - # extract low 32 bits - low_32_bits = llvm.trunc( - T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip - ) - # extract high 32 bits - high_32_bits = llvm.lshr( - value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - high_32_bits = llvm.trunc( - T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip - ) - - low_32_bits_shfl = nvvm.shfl_sync( - T.i32(), - Int32(mask).ir_value(loc=loc, ip=ip), - low_32_bits, - Int32(offset).ir_value(loc=loc, ip=ip), - Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), - kind, - loc=loc, - ip=ip, - ) - high_32_bits_shfl = nvvm.shfl_sync( - T.i32(), - Int32(mask).ir_value(loc=loc, ip=ip), - high_32_bits, - Int32(offset).ir_value(loc=loc, ip=ip), - Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), - kind, - loc=loc, - ip=ip, - ) - - # combine low and high 32 bits - low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip) - high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip) - shlf_res = llvm.shl( - high_64_bit, - Int64(32).ir_value(loc=loc, ip=ip), - llvm.IntegerOverflowFlags.none, - loc=loc, - ip=ip, - ) - shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip) - shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip) - return orig_type(shlf_res) - -shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx) -shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up) -shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down) -shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly) - - -@dsl_user_op -def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None: - """ - Creates a barrier, optionally named. - """ - if barrier_id is not None: - barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) - - if number_of_threads is not None: - number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) - - nvvm.barrier( - barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip - ) - - -@dsl_user_op -def barrier_arrive( - *, barrier_id=None, number_of_threads=None, loc=None, ip=None -) -> None: - if barrier_id is not None: - barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) - - if number_of_threads is None: - raise ValueError( - "barrier_arrive needs pass number_of_threads to arrive the barrier", - ) - number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) - - nvvm.barrier_arrive( - barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip - ) - - -@dsl_user_op -def sync_threads(*, loc=None, ip=None) -> None: - """ - Synchronizes all threads within a CTA. - """ - nvvm.barrier(loc=loc, ip=ip) - - -@dsl_user_op -def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None: - """ - Performs a warp-wide sync with an optional mask. - """ - nvvm.bar_warp_sync(Int32(mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - - -@dsl_user_op -def fence_acq_rel_cta(*, loc=None, ip=None) -> None: - """ - Fence operation with acquire-release semantics. - - See the `PTX documentation `__. - """ - nvvm.fence_acq_rel_cta(loc=loc, ip=ip) - - -@dsl_user_op -def fence_acq_rel_cluster(*, loc=None, ip=None) -> None: - """ - Fence operation with acquire-release semantics. - - See the `PTX documentation `__. - """ - nvvm.fence_acq_rel_cluster(loc=loc, ip=ip) - - -@dsl_user_op -def fence_acq_rel_gpu(*, loc=None, ip=None) -> None: - """ - Fence operation with acquire-release semantics. - - See the `PTX documentation `__. - """ - nvvm.fence_acq_rel_gpu(loc=loc, ip=ip) - - -@dsl_user_op -def fence_acq_rel_sys(*, loc=None, ip=None) -> None: - """ - Fence operation with acquire-release semantics. - - See the `PTX documentation `__. - """ - nvvm.fence_acq_rel_sys(loc=loc, ip=ip) - - -@dsl_user_op -def cp_async_commit_group(*, loc=None, ip=None) -> None: - """ - Commits all prior initiated but uncommitted cp.async instructions. - - See the `PTX documentation `__. - """ - nvvm.cp_async_commit_group(loc=loc, ip=ip) - - -@dsl_user_op -def cp_async_wait_group(n, *, loc=None, ip=None) -> None: - """ - Waits till only a specified numbers of cp.async groups are pending. - - See the `PTX documentation `__. - """ - nvvm.cp_async_wait_group(n, loc=loc, ip=ip) - - -@dsl_user_op -def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None: - """ - Commits all prior initiated but uncommitted cp.async.bulk instructions. - - See the `PTX documentation `__. - """ - nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip) - - -@dsl_user_op -def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None: - """ - Waits till only a specified numbers of cp.async.bulk groups are pending. - - See the `PTX documentation `__. - """ - nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip) - - -@dsl_user_op -def cluster_wait(*, loc=None, ip=None) -> None: - """ - A cluster-wide wait operation. - """ - nvvm.cluster_wait(loc=loc, ip=ip) - - -@dsl_user_op -def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None: - """ - A cluster-wide arrive operation. - """ - nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip) - - -@dsl_user_op -def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None: - """ - A cluster-wide arrive operation with relaxed semantics. - """ - nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip) - - -@dsl_user_op -def fence_proxy( - kind: ProxyKind, - *, - space: Optional[SharedSpace] = None, - use_intrinsic=None, - loc=None, - ip=None, -) -> None: - nvvm.fence_proxy( - kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip - ) - - -@dsl_user_op -def vote_ballot_sync( - pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None -) -> Int32: - """ - Performs a ballot operation across the warp. - """ - return Int32( - nvvm.vote_ballot_sync( - T.i32(), - Int32(mask).ir_value(loc=loc, ip=ip), - Boolean(pred).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -@dsl_user_op -def popc(value: Numeric, *, loc=None, ip=None) -> Numeric: - """ - Performs a population count operation. - """ - if not isinstance(value, Numeric): - value = as_numeric(value) - return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip)) - - -@dsl_user_op -def fence_view_async_tmem_op( - kind: Tcgen05WaitKind, - *, - loc=None, - ip=None, -) -> None: - """ - Perform a fence operation on the async TMEM load or store. - - .. note:: - This function is only available on sm_100a and above. - The fence is required to synchronize the TMEM load/store - and let the pipeline release or commit the buffer. - - Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM. - ``` - # Start to copy ACC from TMEM to register - cute.copy(tmem_load, tACC, rACC) - fence_view_async_tmem_load() - # After fence, we can ensure the TMEM buffer is consumed totally. - # Release the buffer to let the MMA know it can overwrite the buffer. - mma2accum_pipeline.consumer_release(curr_consumer_state) - ``` - Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM. - ``` - # Start to copy A from register to TMEM - cute.copy(tmem_store, rA, tA) - fence_view_async_tmem_store() - # After fence, we can ensure the TMEM buffer is ready. - # Commit the buffer to let the MMA know it can start to load A. - tmem_mma_pipeline.producer_commit(curr_producer_state) - ``` - - - :param kind: The kind of fence operation to perform including LOAD and STORE. - :type kind: Tcgen05WaitKind - """ - nvvm.tcgen05_wait(kind, loc=loc, ip=ip) - - -fence_view_async_tmem_load = partial( - fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD -) -fence_view_async_tmem_store = partial( - fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE -) - - -@dsl_user_op -def warpgroup_reg_realloc_op( - reg_count: int, - kind: SetMaxRegisterAction, - *, - loc=None, - ip=None, -) -> None: - nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip) - - -warpgroup_reg_alloc = partial( - warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase -) -warpgroup_reg_dealloc = partial( - warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease -) - - -@dsl_user_op -def calc_packed_f32x2_op( - src_a: Tuple[Float32, Float32], - src_b: Tuple[Float32, Float32], - src_c: Tuple[Float32, Float32] | None, - calc_func: Callable, - *, - rnd=RoundingModeKind.RZ, - ftz=True, - loc=None, - ip=None, -) -> Tuple[Float32, Float32]: - vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) - vec_src_a = vector.from_elements( - vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip - ) - vec_src_b = vector.from_elements( - vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip - ) - if src_c is not None: - vec_src_c = vector.from_elements( - vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip - ) - vec_res = calc_func( - vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip - ) - else: - vec_res = calc_func( - vec_type, vec_src_a, vec_src_b, rnd=rnd, ftz=ftz, loc=loc, ip=ip - ) - - res0 = Float32( - vector.extract( - vec_res, dynamic_position=[], static_position=[0], loc=loc, ip=ip - ) - ) - res1 = Float32( - vector.extract( - vec_res, dynamic_position=[], static_position=[1], loc=loc, ip=ip - ) - ) - return res0, res1 - - -fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2) -mul_packed_f32x2 = partial( - calc_packed_f32x2_op, src_c=None, calc_func=nvvm.mul_packed_f32x2 -) -add_packed_f32x2 = partial( - calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2 -) - - -@dsl_user_op -def fmax( - a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None -) -> Float32: - return Float32( - nvvm.fmax( - T.f32(), - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -@dsl_user_op -def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None): - return Float32( - nvvm.rcp_approx_ftz_f( - T.f32(), Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - ) - - -@dsl_user_op -@deprecated( - "cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead" -) -def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: - return Float32( - llvm.inline_asm( - T.f32(), - [Float32(a).ir_value(loc=loc, ip=ip)], - "ex2.approx.ftz.f32 $0, $1;", - "=f,f", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -@deprecated( - "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead" -) -def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: - LOG2_E = 1.4426950408889634 - return exp2(a * LOG2_E, loc=loc, ip=ip) - - -@dsl_user_op -@deprecated( - "cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead" -) -def exp_packed_f32x2( - a: Tuple[Float32, Float32], *, loc=None, ip=None -) -> Tuple[Float32, Float32]: - LOG2_E = Float32(1.4426950408889634) - b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip) - return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py deleted file mode 100644 index 37f87ea64d7f7482f3b2f464be6a0ee1a2e3494f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Optional, Type - -from cutlass.cutlass_dsl import T, dsl_user_op - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from ..typing import Pointer, Numeric, NumericMeta - - -@dsl_user_op -def alloc_smem( - element_type: Type[Numeric], - size_in_elems: int, - alignment: Optional[int] = None, - *, - loc=None, - ip=None, -) -> Pointer: - """ - Statically allocates SMEM. - - :param element_type: The pointee type of the pointer. - :type element_type: Type[Numeric] - :param size_in_elems: The size of the allocation in terms of number of elements of the - pointee type - :type size_in_elems: int - :param alignment: An optional pointer alignment for the allocation - :type alignment: int - :return: A pointer to the start of the allocation - :rtype: Pointer - """ - if not isinstance(element_type, NumericMeta): - raise TypeError( - f"element_type must be a type of Numeric, but got {element_type}" - ) - - if alignment is None: - # Default alignment based on the element type's width - alignment = element_type.width // 8 - ptr_ty = _cute_ir.PtrType.get( - element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment - ) - return _cute_nvgpu_ir.arch_alloc_smem( - ptr=ptr_ty, - input=ir.IntegerAttr.get(T.i32(), size_in_elems), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def get_dyn_smem( - element_type: Type[Numeric], - alignment: Optional[int] = None, - *, - loc=None, - ip=None, -) -> Pointer: - """ - Retrieves a pointer to a dynamic SMEM allocation. - - :param element_type: The pointee type of the pointer. - :type element_type: Type[Numeric] - :param alignment: An optional pointer alignment, the result pointer is offset appropriately - :type alignment: int - :return: A pointer to the start of the dynamic SMEM allocation with a correct - alignement - :rtype: Pointer - """ - if not isinstance(element_type, NumericMeta): - raise TypeError( - f"element_type must be a type of Numeric, but got {element_type}" - ) - - if alignment is None: - # Default alignment based on the element type's width - alignment = element_type.width // 8 - ptr_ty = _cute_ir.PtrType.get( - element_type.mlir_type, - _cute_ir.AddressSpace.smem, - alignment, - ) - return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip) - - -@dsl_user_op -def get_dyn_smem_size(*, loc=None, ip=None) -> int: - """ - Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time. - This can be used for bounds checking during shared memory allocation. - - :return: The size of dynamic shared memory in bytes - :rtype: int - """ - return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py deleted file mode 100644 index 302616d20b34ccfe1d3194e48bf94114eeafeaec..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py +++ /dev/null @@ -1,142 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Type - -from cutlass.cutlass_dsl import dsl_user_op - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir - -from ..typing import Pointer, Int, Int32, Numeric, NumericMeta - - -SM100_TMEM_CAPACITY_COLUMNS = 512 -SM100_TMEM_MIN_ALLOC_COLUMNS = 32 - - -@dsl_user_op -def retrieve_tmem_ptr( - element_type: Type[Numeric], - alignment: int, - ptr_to_buffer_holding_addr: Pointer, - *, - loc=None, - ip=None, -) -> Pointer: - """ - Retrieves a pointer to TMEM with the provided element type and alignment. - - :param element_type: The pointee type of the pointer. - :type element_type: Type[Numeric] - :param alignment: The alignment of the result pointer - :type alignment: int - :param ptr_to_buffer_holding_addr: A pointer to a SMEM buffer holding the TMEM address of the - start of the allocation allocation - :type ptr_to_buffer_holding_addr: Pointer - :return: A pointer to TMEM - :rtype: Pointer - """ - if not isinstance(element_type, NumericMeta): - raise TypeError( - f"element_type must be a type of Numeric, but got {element_type}" - ) - - res_ty = _cute_ir.PtrType.get( - element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment - ) - return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr( - res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip - ) - - -@dsl_user_op -def alloc_tmem( - num_columns: Int, - smem_ptr_to_write_address: Pointer, - is_two_cta=None, - *, - loc=None, - ip=None, -) -> None: - """ - Allocates TMEM. - - :param num_columns: The number of TMEM columns to allocate - :type num_columns: Int - :param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is written - to - :type smem_ptr_to_write_address: Pointer - :param is_two_cta: Optional boolean parameter for 2-CTA MMAs - """ - if isinstance(num_columns, int): - if ( - num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS - or num_columns > SM100_TMEM_CAPACITY_COLUMNS - or not (num_columns & (num_columns - 1) == 0) - ): - raise ValueError( - f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}" - ) - _cute_nvgpu_ir.arch_sm100_alloc_tmem( - Int32(num_columns).ir_value(loc=loc, ip=ip), - smem_ptr_to_write_address.value, - is_two_cta=is_two_cta, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None: - """ - Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can - allocate. - """ - _cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit( - is_two_cta=is_two_cta, loc=loc, ip=ip - ) - - -@dsl_user_op -def dealloc_tmem( - tmem_ptr: Pointer, - num_columns: Int, - is_two_cta=None, - *, - loc=None, - ip=None, -) -> None: - """ - Deallocates TMEM using the provided pointer and number of columns. - - :param tmem_ptr: A pointer to the TMEM allocation to de-allocate - :type tmem_ptr: Pointer - :param num_columns: The number of columns in the TMEM allocation - :type num_columns: Int - :param is_two_cta: Optional boolean parameter for 2-CTA MMAs - """ - if isinstance(num_columns, int): - if ( - num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS - or num_columns > SM100_TMEM_CAPACITY_COLUMNS - or not (num_columns & (num_columns - 1) == 0) - ): - raise ValueError( - f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}" - ) - _cute_nvgpu_ir.arch_sm100_dealloc_tmem( - tmem_ptr.value, - Int32(num_columns).ir_value(loc=loc, ip=ip), - is_two_cta=is_two_cta, - loc=loc, - ip=ip, - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py deleted file mode 100644 index 12d5e4221a3e6007656a9400966e84d8b9a25a79..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py +++ /dev/null @@ -1,7070 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import copy as py_copy -from dataclasses import dataclass -import inspect -import math -import operator -from abc import ABC, abstractmethod -from functools import lru_cache, partial, reduce -from inspect import isclass -from itertools import chain -from typing import ( - Callable, - Iterable, - overload, - List, - Tuple, - Union, - Type, - Any, - Dict, - Optional, -) -from enum import Enum, auto - -from cutlass.cutlass_dsl import ( - const, - T, - lru_cache_ir, - is_dynamic_expression, - for_generate, - yield_out, - if_generate, - extract_mlir_values, - new_from_mlir_values, - _binary_op_type_promote, - not_, - cutlass_arith, - dsl_user_op, -) - -from cutlass._mlir import ir -from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results -from cutlass._mlir.dialects import cute as _cute_ir -from cutlass._mlir.dialects.cute import ( - ScaledBasis as _ScaledBasis, - Ratio as _Ratio, -) - -from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects import llvm, builtin, vector, arith - -from .typing import ( - Numeric, - Integer, - NumericMeta, - Boolean, - Int32, - Int8, - Int16, - Int32, - Int64, - Float32, - TFloat32, - Int, - IntTuple, - Shape, - Stride, - Coord, - Layout, - Tile, - Tiler, - XTuple, - Tensor, - Pointer, - AddressSpace, - as_numeric, -) - - -#################################################################################################### -# -# Internal IntTuple helpers -# -#################################################################################################### - - -def _get_typed_value(x): - if isinstance(x, Integer): - return ( - x.value.get_typed_value() if isinstance(x.value, IntValue) else x.ir_value() - ) - else: - return x - - -def _pack_x(x, packer, op, *, loc=None, ip=None) -> ir.Value: - x = transform_leaf(_get_typed_value, x) - res_ty, dyn_elems = packer(x) - # <"0"> is deduced from type inference which should be removed for make_... operations - dyn_elems = [t for t in dyn_elems if not is_static(t)] - return op(res_ty, dyn_elems, loc=loc, ip=ip).result - - -def _pack_shape(shape: Shape, *, loc=None, ip=None) -> ir.Value: - _check_shape(shape) - return _pack_x(shape, _cute_ir.pack_shape, _cute_ir.MakeShapeOp, loc=loc, ip=ip) - - -def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: - _check_stride(stride) - # Convert basis elements to the base class before _pack_x - stride = transform_leaf( - lambda x: x.to(_cute_ir.ScaledBasis) if isinstance(x, ScaledBasis) else x, - stride, - ) - return _pack_x(stride, _cute_ir.pack_stride, _cute_ir.MakeStrideOp, loc=loc, ip=ip) - - -def _pack_coord(coord: Coord, *, loc=None, ip=None) -> ir.Value: - _check_coord(coord) - return _pack_x(coord, _cute_ir.pack_coord, _cute_ir.MakeCoordOp, loc=loc, ip=ip) - - -def _pack_int_tuple(int_tuple: IntTuple, *, loc=None, ip=None) -> ir.Value: - _check_int_tuple(int_tuple) - return _pack_x( - int_tuple, _cute_ir.pack_int_tuple, _cute_ir.MakeIntTupleOp, loc=loc, ip=ip - ) - - -def _pack_tile(tile: Tile, *, loc=None, ip=None) -> ir.Value: - _check_tile(tile) - - def expand_leaves(tile) -> list: - leaves = [] - for e in tile: - if isinstance(e, _Layout): - leaves.extend(list(flatten_to_tuple(e.shape))) - leaves.extend(list(flatten_to_tuple(e.stride))) - else: - leaves.append(e) - return leaves - - layout_leaves = flatten_to_tuple(tile) - dyn_elems = expand_leaves(layout_leaves) - dyn_elems = [ - _get_typed_value(x) for x in dyn_elems if isinstance(x, (Integer, ir.Value)) - ] - - res_ty = _cute_ir.pack_tile(tile) - return _cute_ir.make_tile(res_ty, dyn_elems, loc=loc, ip=ip) - - -def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple: - # If t is an MLIR type, make sure it's static and make a Value - if isinstance(t, ir.Type): - if not _cute_ir.is_static(t): - raise ValueError() - t = _cute_ir.static(t) - - if isinstance(t, ir.Value): - input_ty = t.type - if t.type.rank == 0: - # Handle this case separately, _cute_ir.get_leaves will return an Op in this case - vals = [] - else: - vals = _cute_ir.get_leaves(t, loc=loc, ip=ip) - if not isinstance(vals, list): - vals = [vals] - else: - raise TypeError(f"expects static type or value, but got {t}") - - # CuTe IR only supports Int32 for now. Need to support detection of other types - res = _cute_ir.unpack_x_tuple(input_ty, vals) - - def post_process(x): - if isinstance(x, _cute_ir.ScaledBasis): - return ScaledBasis(post_process(x.get_value()), x.get_mode()) - elif isinstance(x, _cute_ir.Ratio): - return Ratio(x.numerator, x.denominator) - else: - return x - - return transform_leaf(post_process, res) - - -#################################################################################################### -# Validation helpers -#################################################################################################### - - -def _check_shape(shape: Shape) -> None: - if is_integer(shape): - if isinstance(shape, int): - if shape <= 0: - raise ValueError( - f"Expected size in shape to be strictly positive, but got {shape}" - ) - elif isinstance(shape, Integer): - pass - else: - raise TypeError(f"Expected size be int or Integer, but got {type(shape)}") - elif isinstance(shape, tuple): - for s in shape: - _check_shape(s) - else: - raise ValueError( - f"Expected Shape, which is a positive integer or tuple of Shapes, but got {shape}" - ) - - -def _check_coord(coord: Coord) -> None: - flat_coord = flatten_to_tuple(coord) - if not all(is_integer(c) or c is None for c in flat_coord): - raise ValueError( - f"Expected Coord, whose leaves are integers or None, but got {coord}" - ) - - -def _check_stride(stride: Stride) -> None: - flat_stride = flatten_to_tuple(stride) - if not all(is_integer(s) or isinstance(s, ScaledBasis) for s in flat_stride): - raise ValueError( - f"Expected Stride, whose leaves are integers or ScaledBasis, but got {stride}" - ) - - -def _check_int_tuple(int_tuple: IntTuple) -> None: - flat_int_tuple = flatten_to_tuple(int_tuple) - if not all(is_integer(d) for d in flat_int_tuple): - raise ValueError( - f"Expected IntTuple, whose leaves are integers, but got {int_tuple}" - ) - - -def _check_tile(tile: Tile) -> None: - flat_tile = flatten_to_tuple(tile) - if not all(is_integer(t) or isinstance(t, _Layout) or t is None for t in flat_tile): - raise ValueError( - f"Expected Tile, whose leaves are integers or Layout or None, but got {tile}" - ) - - -#################################################################################################### -# -# Core types -# -#################################################################################################### - - -class IntValue(cutlass_arith.ArithValue): - """Internal representation of constrained integer types with divisibility information. - - IntValue serves as a proxy for constrained integer types in the CuTe IR. Rather than - directly storing values of IntTupleType with depth=0, it stores the result of the - `cute.get_scalars` operation applied to such values. - - This class represents the following sequence of operations in the IR: - %0 = ... : (...) -> !cute.int_tuple<"?"> - %1 = cute.get_scalars(%0) : (!cute.int_tuple<"?">) -> i32 - - where the first operation produces a `cute.int_tuple<"?">` with depth=0 and rank=1. It - automatically emit `cute.get_scalars` and track it. - - IntValue inherits behavior from ArithValue with the following extensions: - * Overloaded operations that accept IntTupleType values to propagate divisibility information - * Support for CuTe operations that utilize divisibility constraints - - API for interacting with IntValue: - * get_typed_value() - Returns the value as an IntTupleType - * get_divisibility() - Returns the divisibility constraint of the value - """ - - def __init__(self, v, signed=True): - # Cute Constrained Int Type is always signed - if isinstance(v, int): - v = _pack_int_tuple(v) - - if isinstance(v.type, _cute_ir.IntTupleType): - scalar_val = _cute_ir.get_scalars(v) - super().__init__(scalar_val, True) - else: - super().__init__(v, True) - - def get_typed_value(self): - if isinstance(self.type, ir.IntegerType): - def_op = self.owner.operation - if def_op.name == "cute.get_scalars": - return def_op.operands[0] - - assert not isinstance(self.type, _cute_ir.IntTupleType) - - return _pack_int_tuple(self) - - @property - def divisibility(self): - if isinstance(self.get_typed_value().type, _cute_ir.IntTupleType): - return self.get_typed_value().type.get_divisibility([0]) - else: - return 1 - - def __str__(self): - if self.divisibility == 1: - return f"?" - else: - return f"?{{div={self.divisibility}}}" - - def __repr__(self): - parent_name = cutlass_arith.ArithValue.__name__ - return super().__str__().replace(parent_name, IntValue.__name__) - - def pretty_str(self): - return self.__str__() - - @staticmethod - def _binary_op(op): - def wrapper(self, other, **kwargs): - if isinstance(other, IntValue): - other_val = other.get_typed_value() - elif isinstance(other, ir.Value) and isinstance( - other.type, _cute_ir.IntTupleType - ): - other_val = other - elif isinstance(other, ir.Value) and isinstance(other.type, ir.IntegerType): - other = cutlass_arith.int_to_int(other, Int32, **kwargs) - other_val = _pack_int_tuple(other) - elif isinstance(other, (int, bool)): - other_val = _pack_int_tuple(int(other)) - else: - # Dispatch to `__rmul__` of `other` - return NotImplemented - - return IntValue(op(self, other_val, **kwargs)) - - return wrapper - - @dsl_user_op - @_binary_op - def __add__(self, other, *, loc=None, ip=None): - return _cute_ir.add_offset(self.get_typed_value(), other, loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __sub__(self, other, *, loc=None, ip=None): - return _cute_ir.tuple_sub(self.get_typed_value(), other, loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __mul__(self, other, *, loc=None, ip=None): - return _cute_ir.tuple_mul(self.get_typed_value(), other, loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __floordiv__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_div(self.get_typed_value(), other, loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __mod__(self, other, *, loc=None, ip=None) -> cutlass_arith.ArithValue: - return _cute_ir.tuple_mod(self.get_typed_value(), other, loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __radd__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.add_offset(other, self.get_typed_value(), loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __rsub__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_sub(other, self.get_typed_value(), loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __rmul__(self, other, *, loc=None, ip=None): - return _cute_ir.tuple_mul(other, self.get_typed_value(), loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_div(other, self.get_typed_value(), loc=loc, ip=ip) - - @dsl_user_op - @_binary_op - def __rmod__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_mod(other, self.get_typed_value(), loc=loc, ip=ip) - - -class Ratio(_Ratio): - """A class representing a rational number as a ratio of two integers. - - Ratio is used in CuTe to represent exact fractional values that arise in - tensor layout operations, particularly in composition operations where - divisibility conditions may not be satisfied. - - :param numerator: The numerator of the ratio - :type numerator: int - :param denominator: The denominator of the ratio - :type denominator: int - :raises TypeError: If numerator or denominator are not integers - """ - - def __init__(self, numerator: int, denominator: int): - if not isinstance(numerator, int) or not isinstance(denominator, int): - raise TypeError( - f"numerator and denominator must be integers, but got {numerator} and {denominator}" - ) - super().__init__(numerator, denominator) - - def is_integral(self) -> bool: - """Check if the ratio represents an integer value. - - :return: True if the numerator is divisible by the denominator - :rtype: bool - """ - return super().is_integral() - - def reduced(self) -> "Ratio": - """Return a new Ratio with the numerator and denominator reduced to lowest terms. - - :return: A new Ratio in reduced form - :rtype: Ratio - """ - res = super().reduced() - return Ratio(res.numerator, res.denominator) - - def __mul__(self, other): - """Multiply this ratio by another ratio or an integer. - - :param other: The value to multiply by - :type other: Union[Ratio, int] - :return: A new ratio representing the product - :rtype: Ratio - :raises TypeError: If other is not a Ratio or int - """ - if isinstance(other, Ratio): - return Ratio( - self.numerator * other.numerator, - self.denominator * other.denominator, - ) - elif isinstance(other, int): - return Ratio(self.numerator * other, self.denominator) - else: - raise TypeError(f"Cannot multiply Ratio with {type(other)}") - - def __rmul__(self, other): - """Right multiplication operation. - - :param other: The value to multiply by - :type other: Union[Ratio, int] - :return: A new ratio representing the product - :rtype: Ratio - """ - return self.__mul__(other) - - def __str__(self): - """String representation of the ratio. - - :return: String in the format "numerator/denominator" - :rtype: str - """ - return super().__str__() - - def to(self, dtype): - """Convert the ratio to another type. - - :param dtype: The target type for conversion - :type dtype: type - :return: The ratio converted to the specified type - :raises TypeError: If conversion to the specified type is not supported - """ - if dtype is Ratio: - return self - elif dtype is float: - return self.numerator / self.denominator - elif dtype is int: - return self.numerator // self.denominator - elif issubclass(dtype, _Ratio): - return self - else: - raise TypeError(f"Cannot convert Ratio to {dtype}") - - -class ScaledBasis: - """A class representing a scaled basis element in CuTe's layout algebra. - - ScaledBasis is used to represent elements in the layout algebra, particularly - in the context of composition operations. It consists of a value (scale) and - a mode that identifies mode of the basis element. - - :param value: The scale value - :type value: Union[int, Integer, Ratio, ir.Value] - :param mode: The mode identifying the basis element - :type mode: Union[int, List[int]] - :raises TypeError: If mode is not an integer or list of integers - - **Examples:** - - .. code-block:: python - - # Create a scaled basis with integer scale and mode - sb1 = ScaledBasis(2, 0) # 2 * E(0) - - # Create a scaled basis with a Ratio scale - sb2 = ScaledBasis(Ratio(1, 2), 1) # (1/2) * E(1) - - # Create a scaled basis with a list of modes - sb3 = ScaledBasis(4, [0, 1]) # 4 * E([0, 1]) - - # Scaled basis elements are commonly used in layout strides - layout = make_layout((4, 8), stride=(ScaledBasis(2, 0), ScaledBasis(1, 1))) - - # This creates a layout with strides (2@0, 1@1) representing - # a coordinate system where each dimension has its own basis - - # Example: Mapping coordinates to indices using the layout - coord = (2, 3) - idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3) - """ - - def __init__(self, value, mode) -> None: - if isinstance(mode, int): - self._mode = [mode] - else: - if any(not isinstance(x, int) for x in mode): - raise TypeError("Mode must be a list of integers") - self._mode = mode - - self._value = value - - def is_static(self) -> bool: - """Check if the value is statically known. - - :return: True if the value is not a dynamic expression - :rtype: bool - """ - return not is_dynamic_expression(self._value) - - def to(self, dtype): - """Convert to another type. - - :param dtype: The target type for conversion - :type dtype: type - :return: The ScaledBasis converted to the specified type - :raises TypeError: If conversion to the specified type is not supported - """ - if dtype is ScaledBasis: - return self - elif dtype is _ScaledBasis: - if isinstance(self._value, Ratio): - scale = self._value - elif isinstance(self._value, Integer): - scale = self._value.ir_value() - else: - scale = self._value - - if isinstance(scale, IntValue): - return _ScaledBasis(scale.get_typed_value(), self._mode) - else: - return _ScaledBasis(scale, self._mode) - else: - raise TypeError(f"Cannot convert ScaledBasis to {dtype}") - - def __str__(self): - return f"{self.to(_ScaledBasis).__str__()}" - - def __hash__(self): - if isinstance(self.mode, list): - return hash((self.value, tuple(self.mode))) - else: - return hash((self.value, self.mode)) - - @property - def value(self): - """Get the scale value. - - :return: The scale value - """ - return self._value - - @property - def mode(self) -> List[int]: - """Get the mode identifying the basis element. - - :return: The mode as a list of integers - :rtype: List[int] - """ - return self._mode - - def __eq__(self, other): - if isinstance(other, ScaledBasis): - return self.value == other.value and self.mode == other.mode - else: - return False - - def __rmul__(self, scale: Union[Int, ir.Value, Ratio]) -> "ScaledBasis": - """Right multiplication by a scale factor. - - This operation is used in layout algebra to scale basis elements, - which is essential for operations like composition and partitioning. - - :param scale: The scale factor - :type scale: Union[Int, ir.Value, Ratio] - :return: A new scaled basis element - :rtype: ScaledBasis - :raises TypeError: If scale is not of a supported type - :raises NotImplementedError: If scaling a basis element with a ratio value - """ - if not isinstance(scale, (int, Integer, Ratio, ir.Value)): - raise TypeError( - f"scale must be an integer or a ratio, but got {type(scale)}" - ) - if isinstance(self.value, Ratio): - raise NotImplementedError( - "scaling a basis element having a ratio is not supported" - ) - - value = self.value - - if not isinstance(value, (Integer, Ratio, int, cutlass_arith.ArithValue)): - raise TypeError(f"Don't support {type(value)} for ScaledBasis") - - # Lift to IntValue type to preserve type info as much as possible - if isinstance(scale, cutlass_arith.ArithValue): - scale = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(scale, Int32))) - - if isinstance(value, cutlass_arith.ArithValue): - value = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(value, Int32))) - elif isinstance(value, Integer): - value = value.ir_value() - - return ScaledBasis(scale * value, self.mode) # type: ignore - - -def E(mode: Union[int, List[int]]) -> ScaledBasis: - """Create a unit ScaledBasis element with the specified mode. - - This function creates a ScaledBasis with value 1 and the given mode. - The mode represents the coordinate axis or dimension in the layout. - - :param mode: The mode (dimension) for the basis element, either a single integer or a list of integers - :type mode: Union[int, List[int]] - :return: A ScaledBasis with value 1 and the specified mode - :rtype: ScaledBasis - :raises TypeError: If mode is not an integer or a list - - **Examples:** - - .. code-block:: python - - # Create a basis element for the first dimension (mode 0) - e0 = E(0) - - # Create a basis element for the second dimension (mode 1) - e1 = E(1) - - # Create a basis element for a hierarchical dimension - e_hier = E([0, 1]) - """ - if isinstance(mode, int): - mode = [mode] - - if not isinstance(mode, list): - raise TypeError(f"expects a list, got {type(mode)}") - - if not mode: - return 1 - - return ScaledBasis(1, mode) - - -def get_divisibility(x: Union[int, Integer]) -> int: - if isinstance(x, int): - return x - - if isinstance(x, Integer): - x = x.value - - if isinstance(x, IntValue): - return x.divisibility - else: - return 1 - - -@ir.register_value_caster(_cute_ir.SwizzleType.get_static_typeid(), replace=True) -class Swizzle(ir.Value): - """ - Swizzle is a transformation that permutes the elements of a layout. - - Swizzles are used to rearrange data elements to improve memory access patterns - and computational efficiency. - - Swizzle is defined by three parameters: - - MBase: The number of least-significant bits to keep constant - - BBits: The number of bits in the mask - - SShift: The distance to shift the mask - - The mask is applied to the least-significant bits of the layout. - - .. code-block:: - - 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx - ^--^ MBase is the number of least-sig bits to keep constant - ^-^ ^-^ BBits is the number of bits in the mask - ^---------^ SShift is the distance to shift the YYY mask - (pos shifts YYY to the right, neg shifts YYY to the left) - - e.g. Given - 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx - - the result is - 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ `xor` YY - - """ - - def __str__(self): - # Cut off the MLIR type's string for making pretty_str more concise - return self.type.__str__()[15 : 15 + 8] - - -@ir.register_value_caster(_cute_ir.LayoutType.get_static_typeid(), replace=True) -class _Layout(Layout): - """Layout is CuTe's core abstraction for representing tensor layouts. - - A Layout maps from a logical coordinate space to an index space, defined by a - pair of (Shape, Stride). The Shape defines the abstract dimensions of the Layout, - while the Stride defines how coordinates within the Shape map to linear indices. - - Layouts present a common interface to multidimensional array access that abstracts - away the details of how array elements are organized in memory. This allows algorithms - to be written generically, so that layouts can change without requiring code changes. - - CuTe layouts are inherently hierarchical, constructed from smaller, nested layouts - that can represent complex mappings required by GPU tensor instructions. They support - a rich algebra of operations including concatenation, coalescence, composition, - complement, and inversion. - - :ivar shape: An IntTuple representing the dimensions of the layout. - :ivar stride: An IntTuple representing the strides of the layout. - :ivar max_alignment: The maximum alignment of the layout. - - **Examples:** - - .. code-block:: python - - # Creating a layout with shape (4,8) and default stride (layout left / "column major") - layout = cute.make_layout((4, 8)) - - # Creating a layout with explicit shape and stride - layout = cute.make_layout((4, 8), stride=(8, 1)) - - # Accessing a specific coordinate: (2, 3) -> 2 * 8 + 3 * 1 = 19 - idx = cute.crd2idx((2, 3), layout) - """ - - def __init__(self, op_result) -> None: - """Initialize a Layout object. - - :param op_result: The operation result value to wrap. - """ - super().__init__(op_result) - - def __str__(self) -> str: - """Return a string representation of the layout. - - :return: A string in the format "shape:stride". - """ - return f"{pretty_str(self.shape)}:{pretty_str(self.stride)}" - - @property - def shape(self, *, loc=None, ip=None) -> Shape: - """Get the shape of the layout. - - The shape defines the dimensions and structure of the layout's - coordinate space. - - :param loc: Optional location information for debugging. - :param ip: Optional insertion point for IR generation. - :return: The hierarchical shape of the layout. - """ - return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) - - @property - def stride(self, *, loc=None, ip=None) -> Stride: - """Get the stride of the layout. - - The stride defines how coordinates map to linear indices in memory. - - :param loc: Optional location information for debugging. - :param ip: Optional insertion point for IR generation. - :return: The hierarchical stride of the layout. - """ - return _unpack_x_tuple( - _cute_ir.get_stride(self, loc=loc, ip=ip), loc=loc, ip=ip - ) - - @property - def max_alignment(self) -> int: - """Get the maximum alignment of the layout. - - :return: The maximum alignment in bytes. - """ - return self.type.max_alignment - - def __eq__(self, other) -> Union[bool, Boolean]: - """Check if this layout is equal to another layout. - - Two layouts are equal if they have the same shape and stride. - - :param other: The layout to compare with. - :return: True if layouts are equal, False otherwise. - May return an IR value for dynamic layouts. - """ - if isinstance(other, Layout): - if is_static(self.type) and is_static(other.type): - return self.type == other.type - return Boolean(_cute_ir.equal(self, other)) - else: - return False - - def __req__(self, other) -> Union[bool, Boolean]: - """Reflected equality check. - - :param other: The layout to compare with. - :return: Result of other.__eq__(self). - """ - if isinstance(other, Layout): - return other.__eq__(self) - return False - - def __ne__(self, other) -> Union[bool, Boolean]: - """Check if this layout is not equal to another layout. - - :param other: The layout to compare with. - :return: True if layouts are not equal, False otherwise. - """ - if isinstance(other, Layout): - if is_static(self.type) and is_static(other.type): - return self.type != other.type - return Boolean(not_(_cute_ir.equal(self, other))) - else: - return True - - def __rne__(self, other) -> Union[bool, Boolean]: - """Reflected inequality check. - - :param other: The layout to compare with. - :return: Result of other.__ne__(self). - """ - if isinstance(other, Layout): - return other.__ne__(self) - return False - - def __getitem__(self, idx: int) -> Layout: - """ - Top-level `get` to provide a syntax similar to `tuple`. - """ - return get(self, mode=[idx]) - - @dsl_user_op - def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: - return crd2idx(coord, self, loc=loc, ip=ip) - - @dsl_user_op - def get_hier_coord(self, idx, *, loc=None, ip=None) -> Coord: - """Get the hierarchical coordinate corresponding to a linear index. - - This method maps from a linear index back to the logical coordinate - in the layout's coordinate space. - - :param idx: The linear index to convert. - :return: The hierarchical coordinate corresponding to the index. - - **Examples:** - - .. code-block:: python - - layout = make_layout((4, 8), stride=(8, 1)) - - # map linear index back to coordinate: 5 -> (1, 1) - coord = get_hier_coord(5, layout) - """ - idx_val = Int32(idx).ir_value() - crd = _cute_ir.get_hier_coord(idx_val, self, loc=loc, ip=ip) - return _unpack_x_tuple(crd) - - @dsl_user_op - def get_flat_coord(self, idx, *, loc=None, ip=None) -> Coord: - idx_val = Int32(idx).ir_value() - res = _cute_ir.get_flat_coord(idx_val, self, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - - -@ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True) -class ComposedLayout(ir.Value): - r"""ComposedLayout represents the functional composition of layouts in CuTe. - - A ComposedLayout is formed by the composition of three components: - inner o offset o outer, where: - - - inner: The inner layout or swizzle that is applied last - - offset: An integer tuple representing a coordinate offset - - outer: The outer layout that is applied first - - ComposedLayout implements the functional composition operation where: - - .. math:: - - R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c)) - - This composition allows for complex transformations of coordinates and indices, - enabling operations like tiling, partitioning, and reshaping of data. - - :ivar inner: The inner layout or swizzle component - :ivar offset: The coordinate offset applied between inner and outer layouts - :ivar outer: The outer layout component - :ivar max_alignment: The maximum alignment of the composed layout - - **Examples:** - - .. code-block:: python - - # Create a composed layout with inner layout, offset, and outer layout - - # inner layout: (4, 8):(1, 4) - inner_layout = make_layout((4, 8)) - - offset = (0, 0) - - # outer layout: (2, 2):(1@0, 1@1) - outer_layout = make_layout((2, 2), stride=(1 * E(0), 1 * E(1))) - - # composed layout: (inner o offset o outer) - composed = make_composed_layout(inner_layout, offset, outer_layout) - - # Accessing components of the composed layout - inner = composed.inner - offset = composed.offset - outer = composed.outer - - # map coordinate (0, 1) to linear index - # - outer(0, 1) = (0, 1) - # - offset + outer(0, 1) = (0, 1) - # - inner(0, 1) = 0 * 1 + 1 * 4 = 4 - idx = crd2idx((0, 1), composed) - - # Composition is used in many tiling operations - # For example, in logical_product, raked_product, and blocked_product - """ - - def __init__(self, value) -> None: - """Initialize a ComposedLayout object. - - :param value: The operation result value to wrap. - """ - super().__init__(value) - - def __str__(self) -> str: - return f"{pretty_str(self.inner)} o {pretty_str(self.offset)} o {pretty_str(self.outer)}" - - @property - def inner(self, *, loc=None, ip=None) -> Union[Swizzle, Layout]: - return _cute_ir.composed_get_inner(self, loc=loc, ip=ip) - - @property - def offset(self, *, loc=None, ip=None) -> IntTuple: - return _unpack_x_tuple(_cute_ir.composed_get_offset(self, loc=loc, ip=ip)) - - @property - def outer(self, *, loc=None, ip=None) -> Layout: - return _cute_ir.composed_get_outer(self, loc=loc, ip=ip) - - @property - def shape(self, *, loc=None, ip=None) -> Shape: - return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) - - @property - def max_alignment(self) -> int: - return self.type.max_alignment - - def __eq__(self, other) -> Union[bool, Boolean]: - if isinstance(other, ComposedLayout): - if is_static(self.type) and is_static(other.type): - return self.type == other.type - else: - raise NotImplementedError( - f"runtime comparison of composed layouts is not supported, got `{self}` and `{other}`" - ) - else: - return False - - def __req__(self, other) -> Union[bool, Boolean]: - if isinstance(other, ComposedLayout): - return Boolean(other.__eq__(self)) - return False - - def __ne__(self, other) -> Union[bool, Boolean]: - return not self.__eq__(other) - - def __rne__(self, other) -> Union[bool, Boolean]: - if isinstance(other, ComposedLayout): - return other.__ne__(self) - return False - - def __getitem__(self, idx: int) -> "ComposedLayout": - """ - Top-level `get` to provide a syntax similar to `tuple`. - """ - return get(self, mode=[idx]) - - @dsl_user_op - def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: - return crd2idx(coord, self, loc=loc, ip=ip) - - -@ir.register_value_caster(_cute_ir.PtrType.get_static_typeid(), replace=True) -class _Pointer(Pointer): - """ - A pointer class representing a memory address with specific properties. - - Pointers are a fundamental type of iterator/engine that support random-access operations. - They can be offset by elements of a layout's codomain and dereferenced to produce values. - - :param value: The MLIR operation result value to initialize the pointer with - :type value: ir.Value - - :ivar type: The MLIR type of the pointer - :vartype type: Type - :ivar value_type: The type of value this pointer points to - :vartype value_type: Type - :ivar memspace: The memory space where the pointer data resides (e.g., gmem, smem, rmem) - :vartype memspace: AddressSpace - - :note: When composed with a layout, a pointer forms a tensor: T = E ∘ L, where E is the pointer - and L is the layout. The tensor evaluates the layout by mapping a coordinate c to the - codomain, offsets the pointer accordingly, and dereferences the result: - T(c) = (E ∘ L)(c) = *(E + L(c)) - """ - - def __init__(self, value) -> None: - assert isinstance(value, ir.Value) - self.value = ir.Value(value) - - def __str__(self) -> str: - # Cut off the MLIR type's string for making pretty_str more concise - return self.type.__str__()[6:] - - def __get_mlir_types__(self): - return [self.value.type] - - def __extract_mlir_values__(self): - return [self.value] - - def __new_from_mlir_values__(self, values): - # Only expecting single value of _Pointer instance or ir.Value - # In this context, a _Pointer instance is an encapsulated ir.Value which is automatically created - # by value caster for cute.ptr typed values - assert len(values) == 1, f"Expected 1 value, but got {len(values)}" - assert isinstance( - values[0], (_Pointer, ir.Value) - ), f"Expected _Pointer or ir.Value, but got {type(values[0])}" - return _Pointer( - values[0] if isinstance(values[0], ir.Value) else values[0].value - ) - - @property - @lru_cache_ir() - def dtype(self) -> Type[Numeric]: - return Numeric.from_mlir_type(self.value.type.value_type) - - @property - def alignment(self) -> int: - return self.type.alignment - - @property - def max_alignment(self) -> int: - return self.type.max_alignment - - @property - @lru_cache_ir() - def memspace(self) -> AddressSpace: - return AddressSpace(self.type.address_space) - - # Make it behave as if it inherited from ir.Value - @property - @lru_cache_ir() - def type(self) -> ir.Type: - return self.value.type - - # Only use if you absolutely need to get the LLVM pointer Value - @property - @lru_cache_ir() - def llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: - """ - Get the LLVM pointer representation of this pointer. - - :param loc: The source location for the operation, defaults to None - :type loc: Location, optional - :param ip: The insertion point for the operation, defaults to None - :type ip: InsertionPoint, optional - :return: The LLVM pointer representation - :rtype: ir.Value - """ - llvm_ptr_ty = llvm.PointerType.get(self.memspace.value) - return builtin.unrealized_conversion_cast( - [llvm_ptr_ty], [self.value], loc=loc, ip=ip - ) - - def __add__(self, offset: IntTuple) -> Pointer: - """ - Offset the pointer by elements of a layout's codomain. - - :param offset: The offset to add to the pointer - :type offset: IntTuple - :return: A new pointer offset by the specified amount - :rtype: ir.Value - """ - offset = _pack_int_tuple(offset) - return _cute_ir.add_offset(self.value, offset=offset) - - @dsl_user_op - def toint(self, *, loc=None, ip=None): - if self.memspace in (AddressSpace.gmem, AddressSpace.generic): - res_type = Int64 - else: - res_type = Int32 - - return res_type( - _cute_ir.ptrtoint(res_type.mlir_type, self.value, loc=loc, ip=ip) - ) - - @dsl_user_op - def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: - """ - Align a pointer to a specified byte alignment. - - :param min_align: The minimum byte alignment requirement. Must be a power of 2. - :type min_align: int - :param loc: The source location for the operation, defaults to None - :type loc: Location, optional - :param ip: The insertion point for the operation, defaults to None - :type ip: InsertionPoint, optional - :return: The aligned new pointer that satisfies alignment request. - :rtype: Pointer - :raises ValueError: If the alignment is not a power of 2. - :raises TypeError: If pointer is in tmem address space. - """ - - if (min_align & (min_align - 1)) != 0: - raise ValueError("Alignment must be a power of 2") - - assert isinstance(self.type, _cute_ir.PtrType) - if self.memspace is AddressSpace.tmem: - raise ValueError("aligning a TMEM pointer is not supported") - - if min_align <= self.alignment: - return self - - dtype = Numeric.from_mlir_type(self.type.value_type) - # Convert pointer to integer - address_int = self.toint(loc=loc, ip=ip) - # Align the address - aligned_address = (address_int + min_align - 1) & ~(min_align - 1) - - return make_ptr( - dtype, - aligned_address, - self.memspace, - assumed_align=min_align, - loc=loc, - ip=ip, - ) - - -@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) -@ir.register_value_caster(_cute_ir.CoordTensorType.get_static_typeid(), replace=True) -@ir.register_value_caster( - _cute_nvgpu_ir.SmemDescViewType.get_static_typeid(), replace=True -) -class _Tensor(Tensor): - """A tensor class representing the composition of an iterator (engine) with a layout. - - A tensor evaluates the layout by mapping a coordinate to the codomain, offsets the - iterator accordingly, and dereferences the result to obtain the tensor's value. - Formally: T(c) = (E ∘ L)(c) = *(E + L(c)), where E is the iterator/engine and L is the layout. - - :param value: The MLIR operation result value to initialize the tensor with - :type value: ir.Value - :param dtype: The user specified data type of the tensor elements. It could be \ - different from the underlying dtype in the iterator. The default is None. - :type dtype: Type[Numeric], optional - - Attributes: - iterator: The pointer or iterator (engine) component of the tensor - layout: The layout component defining the mapping from coordinates to offsets - shape: The shape of the tensor, inherited from the layout - stride: The stride of the tensor, inherited from the layout - element_type: The data type of the tensor elements - memspace: The memory space where the tensor data resides - - Notes: - - The tensor supports both direct element access via coordinates and slicing operations - - Load/store operations are only supported for specific memory spaces (rmem, smem, gmem, generic) - - For composed layouts, stride information is not directly accessible - - Dynamic layouts do not support vector load/store operations - - **Examples:** - - .. code-block:: python - - # Create a tensor with shape (4,8) in row-major layout - tensor = make_tensor(ptr, make_layout(shape=(4,8), stride=(8,1))) - - # Access individual element - val = tensor[0, 0] # or val = tensor[(0, 0)] - - # Slice operation - get first column - subtensor = tensor[None, 0] # or subtensor = tensor[(None, 0)] - """ - - def __init__(self, value, dtype: Optional[Type[Numeric]] = None): - self._dtype = dtype - if isinstance(value, ir.Value): - self.value = value - elif isinstance(value, _Tensor): - self.value = value.value - else: - raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}") - - # Set iterator - iter_val = _cute_ir.get_iter(self.value) - if isinstance(iter_val, Pointer): - self._iterator = iter_val - elif isinstance(iter_val.type, _cute_ir.IntTupleType): - self._iterator = _unpack_x_tuple(iter_val) - elif isinstance(iter_val, ir.Value): - # Example: SMEM descriptor iterator, not well supported today - self._iterator = iter_val - else: - raise TypeError(f"unsupported iterator type, got {type(iter_val)}") - - # Set dtype - if self._dtype is None: - if is_int_tuple(self.iterator): - self._dtype = IntTuple - elif isinstance(self.iterator, Pointer): - self._dtype = self.iterator.value_type - elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): - # SmemDescViewType do not need dtype - self._dtype = None - else: - raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") - - def __str__(self): - return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" - - def __extract_mlir_values__(self): - return [self.value] - - def __new_from_mlir_values__(self, values): - # Only expecting single value of _Tensor or ir.Value - # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created - # by value caster for MemRef/CoordTensor/SmemDescView typed values - assert len(values) == 1, f"Expected 1 value, but got {len(values)}" - assert isinstance( - values[0], (_Tensor, ir.Value) - ), f"Expected _Tensor or ir.Value, but got {type(values[0])}" - return _Tensor( - values[0] if isinstance(values[0], ir.Value) else values[0].value, - dtype=self.element_type, - ) - - # Cheat to let `Type(_Tensor())` to return cute.Tensor - @property - def __class__(self) -> Type[Tensor]: - return Tensor - - # Make it behave as if it inherited from ir.Value - @property - @lru_cache_ir() - def type(self) -> ir.Type: - return self.value.type - - @dsl_user_op - def __getitem__( - self, crd: Coord, *, loc=None, ip=None - ) -> Union[Tensor, Numeric, IntTuple]: - """Access or slice tensor elements using coordinates. - - This method implements - * tensor evaluation T(c) = *(E + L(c)) when `c` is a coordinate without slicing, or - * tensor slicing operations T(c) = make_tensor(E + L(c), slice(L, c)) - where E is the iterator/engine and L is the layout - - :param crd: Coordinate or slice specification for accessing tensor elements - :type crd: Coord - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Tensor element value or sliced subtensor - :rtype: Union[Tensor, ir.Value, IntTuple] - - :raises ValueError: If coordinate access is invalid for the tensor layout - - **Examples:** - - .. code-block:: python - - # Create a tensor with pointer iterator - ptr = make_ptr(cutlass.Float32, 0, cutlass.AddressSpace.gmem) - layout = make_layout((64, 128)) # leftmost mode is major - tensor = make_tensor(ptr, layout) # Tensor using pointer iterator - - # Direct element access loads from memory - val = tensor[0] # Loads element at offset 0 - val = tensor[1] # Loads element at offset 4 (4bytes per Float32) - val = tensor[(0, 1)] # Loads element at offset 64 - - # Create a coord tensor - layout = make_layout((64, 128), stride=(1 * E(0), 1 * E(1))) - tensor = make_tensor((128, 128), layout) - - # Direct element access - val = tensor[0] # Returns (128, 128) - val = tensor[(0, 1)] # Returns (128, 129) - - # Slice access - sliced = view[(3, None)] # Returns tensor slice - - .. note:: - Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar - dereference operations. Attempting to set individual elements of tensors with - these element types will result in errors. - - **Examples:** - - .. code-block:: python - - # Unsupported operations with sub-byte types: - ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - # The following will raise an error: - val = tensor[0] # Error: sub-byte scalar dereference not supported - - # Similarly for other sub-byte types: - ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - val = tensor[0] # Error: sub-byte scalar dereference not supported - """ - if has_underscore(crd): - return slice_(self.value, crd) - elif isinstance(self.type, _cute_ir.CoordTensorType): - res = _cute_ir.get_iter(slice_(self, crd).value, loc=loc, ip=ip) - return _unpack_x_tuple(res) - else: - self._check_can_load_store() - self._check_can_dereference() - - crd_val = _pack_coord(crd, loc=loc, ip=ip) - data_val = _cute_ir.memref_load(self.value, crd_val, loc=loc, ip=ip) - return self.element_type(data_val) - - def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): - orig_dtype = data.dtype - # Implicit upcast to wider type - if ( - data.dtype.is_same_kind(self.element_type) - and self.element_type.width >= data.dtype.width - ): - data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore - - if data.dtype.width != self.element_type.width: - raise ValueError( - f"Type mismatch, store {orig_dtype} (-> {data.dtype}) " - f"to Tensor with element type {self.element_type}" - ) - - if data.dtype is Boolean and self.element_type is Boolean: - # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory - val = data.ir_value_int8() - else: - val = data.ir_value() - return val - - @dsl_user_op - def __setitem__( - self, - crd: Coord, - data: Union[int, float, ir.Value, Numeric, "TensorSSA"], - *, - loc=None, - ip=None, - ) -> None: - """Set tensor elements at specified coordinates. - - Assigns values to tensor elements through direct coordinate access or slice assignment. - For slice assignment, the value must be a TensorSSA with matching shape. - - :param crd: Coordinate or slice specification for tensor element assignment - :type crd: Coord - :param data: Value to assign - can be scalar or TensorSSA for slice assignment - :type data: Union[int, float, ir.Value, Numeric, TensorSSA] - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - - :raises ValueError: If tensor type doesn't support load/store operations - :raises ValueError: If slice assignment value is not a TensorSSA - :raises ValueError: If value type doesn't match tensor element type - :raises NotImplementedError: If value type is not supported - - .. note:: - Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar - dereference operations. Attempting to set individual elements of tensors with - these element types will result in errors. - - **Examples:** - - .. code-block:: python - - # Unsupported operations with sub-byte types: - ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - # The following will raise an error: - tensor[0] = 1.0 # Error: sub-byte scalar dereference not supported - - # Similarly for other sub-byte types: - ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - tensor[0] = 0.5 # Error: sub-byte scalar dereference not supported - """ - self._check_can_load_store() - - # convert scalar type - if not has_underscore(crd): - self._check_can_dereference() - # First, convert ir.Value to Numeric - if isinstance(data, ir.Value): - data = as_numeric(data) - elif isinstance(data, (int, float, bool)): - data = as_numeric(data) - - if not isinstance(data, Numeric): - raise ValueError(f"unsupported data type: {type(data)}") - - # Implicit upcast to wider type - val = self._cvt_to_dest(data, loc=loc, ip=ip) - if val.type != self.type.value_type: - raise ValueError( - f"type mismatch, store {val.type} to {self.element_type}" - ) - - crd_val = _pack_coord(crd, loc=loc, ip=ip) - _cute_ir.memref_store(self.value, crd_val, val, loc=loc, ip=ip) - else: - if not isinstance(data, TensorSSA): - raise ValueError(f"expects TensorSSA, but got {data}") - - self.__getitem__(crd).store(data, loc=loc, ip=ip) # type: ignore - - @property - def __class__(self) -> Type[Tensor]: - return Tensor - - # Make it behave as if it inherited from ir.Value - @property - @lru_cache_ir() - def type(self) -> ir.Type: - return self.value.type - - @property - def iterator(self) -> Union[Pointer, IntTuple]: - return self._iterator - - @property - def layout(self) -> Layout: - return _cute_ir.get_layout(self.value) - - @property - def shape(self) -> Shape: - return self.layout.shape - - @property - def stride(self) -> Stride: - if isinstance(self.type, _cute_ir.ComposedLayoutType): - raise ValueError(f"can't get stride from composed layout") - return self.layout.stride - - @property - def leading_dim(self) -> Union[int, Tuple[int], None]: - """Get the leading dimension of this Tensor. - - :return: The index or indices of the first mode (from left to right) with stride 1 - :rtype: Union[int, Tuple[int], None] - :returns: - - int: Single leading dimension index if found - - Tuple[int]: Tuple of indices for nested leading dimensions - - None: If no leading dimension is found - - :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` - """ - return leading_dim(self.shape, self.stride) - - @property - @lru_cache_ir() - def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: - return self._dtype - - @property - @lru_cache_ir() - def memspace(self) -> AddressSpace: - if isinstance(self.iterator, Pointer): - return self.iterator.memspace - - raise ValueError(f"{self} doesn't have memspace") - - @dsl_user_op - def load(self, *, loc=None, ip=None) -> "TensorSSA": - """Load tensor elements as a vector. - - Loads all elements of the tensor into a vector representation, assuming the tensor - has a static shape and is in a memory space that supports load operations. - - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Vector representation of tensor elements - :rtype: TensorSSA - - :raises ValueError: If tensor has dynamic layout - :raises ValueError: If tensor memory space doesn't support load operations - """ - if not is_static(self.shape): - raise ValueError("dynamic layout doesn't support load") - - self._check_can_load_store() - - res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip) - if self.element_type is Boolean: - assert ( - res_vect.type.element_type == T.i8() - ), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}" - zeros = full_like(self, 0, Int8, loc=loc, ip=ip) - res_vect = arith.cmpi( - arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip - ) - return TensorSSA(res_vect, self.shape, self.element_type) - - @dsl_user_op - def store(self, data: "TensorSSA", *, loc=None, ip=None): - """Store vector data into tensor. - - Stores vector data into the tensor, assuming matching shapes and a memory space - that supports store operations. - - :param data: Vector data to store into tensor - :type data: TensorSSA - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - - :raises ValueError: If tensor has dynamic layout - :raises ValueError: If tensor memory space doesn't support store operations - :raises ValueError: If data shape doesn't match tensor shape - """ - if not isinstance(data, TensorSSA): - raise ValueError(f"Expects TensorSSA, but got {type(data)}") - - if not is_static(self.shape): - raise ValueError("Dynamic layout doesn't support vectorized store") - - self._check_can_load_store() - - n_elems = size(self.shape, loc=loc, ip=ip) - if n_elems != size(data.shape, loc=loc, ip=ip): - raise ValueError( - f"lhs and rhs must have the same shape, but got {self.shape} and {data.shape}" - ) - - elem_mlir_type = cutlass_arith.element_type(data.dtype.mlir_type) - if cutlass_arith.is_narrow_precision(elem_mlir_type): - if elem_mlir_type.width * n_elems % 32 != 0: - raise ValueError( - f"narrow precision type must be 32-bit aligned vector, but got {elem_mlir_type} with {n_elems} elements" - ) - - # Implicit upcast to wider type - new_data = self._cvt_to_dest(data, loc=loc, ip=ip) - - return _cute_ir.memref_store_vec( - new_data, self.value, row_major=True, loc=loc, ip=ip - ) - - @dsl_user_op - def fill(self, value: Numeric, *, loc=None, ip=None) -> None: - """Fill tensor with a constant value. - - Fills all elements of the tensor with the specified value, assuming static size - and supported memory space. - - :param value: Value to fill tensor with - :type value: Union[int, float] - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - - :raises NotImplementedError: If tensor has dynamic size - - **Examples:** - - .. code-block:: python - - # Create tensor from numpy array - b = np.random.randn(4, 8).astype(np.float32) - tensor = from_dlpack(b) - - # Fill tensor with constant value - tensor.fill(0.5) # All elements become 0.5 - """ - self._check_can_load_store() - - sz = size(self, loc=loc, ip=ip) - if type(sz) is not int: - raise NotImplementedError(f"dynamic size is not supported: {self.type}") - - # Should we cast to destination type even with narrow cast? - dst_type = self.element_type - value = dst_type(value) - - self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip) - - def _check_can_load_store(self): - if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in ( - AddressSpace.rmem, - AddressSpace.smem, - AddressSpace.gmem, - AddressSpace.generic, - ): - raise ValueError(f"{self} doesn't support load and store") - - def _check_can_dereference(self): - # Check for sub-byte types and raise error if needed - if self.element_type.width % 8 != 0 and self.element_type is not Boolean: - raise ValueError( - f"Sub-byte scalar dereference not supported for type {self.element_type}" - ) - - -@dsl_user_op -def print_tensor( - tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None -): - """Print content of the tensor in human readable format. - - Outputs the tensor data in a structured format showing both metadata - and the actual data values. The output includes tensor type information, - layout details, and a formatted array representation of the values. - - :param tensor: The tensor to print - :type tensor: Tensor - :param verbose: If True, includes additional debug information in the output - :type verbose: bool - :param loc: Source location where it's called, defaults to None - :type loc: source location, optional - :param ip: Insertion pointer for IR generation, defaults to None - :type ip: insertion pointer, optional - :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing - - **Example output:** - - .. code-block:: text - - tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= - [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], - [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], - [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], - [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], - [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], - [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], - [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], - [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) - """ - if isinstance(tensor, TensorSSA): - tmp = make_fragment(tensor.shape, tensor.dtype) - tmp.store(tensor) - tensor = tmp - - if not isinstance(tensor.type, _cute_ir.MemRefType): - raise NotImplementedError( - f"printing {tensor} is not supported because it doesn't support trivial dereferencing. " - f"Coordinate Tensor will be supported in the future." - ) - - tensor._check_can_load_store() # type: ignore - - if tensor.element_type.is_integer: - signed = tensor.element_type.signed - else: - signed = False - - _cute_ir.print_view(tensor.value, verbose=verbose, is_signed=signed, loc=loc, ip=ip) - - -#################################################################################################### -# -# Core API -# -#################################################################################################### - - -# -# Utilties -# - - -@lru_cache_ir() -def is_integer(a) -> bool: - """Check if an object is static integer or dynamic integer""" - return isinstance(a, (int, Integer)) or ( - isinstance(a, ir.Value) - and isinstance(a.type, (ir.IntegerType, _cute_ir.ConstrainedIntType)) - ) - - -def is_valid_leaf(a) -> bool: - """ - Returns whether `a` has a type that is valid for a CuTe tuple's leaf. - """ - return ( - is_integer(a) - or (a is None) - or isinstance(a, (ScaledBasis, Layout, ComposedLayout)) - ) - - -def is_int_tuple(a) -> bool: - if isinstance(a, tuple): - return all([is_int_tuple(x) for x in a]) - else: - return is_integer(a) - - -def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool: - """Check if a value is statically known at compile time. - - In CuTe, static values are those whose values are known at compile time, - as opposed to dynamic values which are only known at runtime. - - :param x: The value to check - :type x: Union[ir.Type, ir.Value, XTuple] - :return: True if the value is static, False otherwise - :rtype: bool - :raises TypeError: If an unsupported type is provided - """ - if isinstance(x, ir.Type): - return _cute_ir.is_static(x) - elif isinstance(x, tuple): - return all(is_static(a) for a in x) - # Can it be a static int? - elif isinstance(x, Numeric): - return False - elif is_dynamic_expression(x): - return _cute_ir.is_static(x.type) - elif isinstance(x, (bool, int, float)) or x is None: - return True - elif isinstance(x, ScaledBasis): - return x.is_static() - else: - raise TypeError(f"unsupported type {x}") - - -def has_underscore(a: XTuple) -> bool: - if type(a) is tuple: - return any([has_underscore(x) for x in a]) - else: - return a is None - - -def has_scaled_basis(a: XTuple) -> bool: - """Check if a tuple or its nested elements contain ScaledBasis objects. - - ScaledBasis objects are fundamental components in CuTe layouts, - representing the basis vectors of coordinate systems. - - :param a: The tuple to check - :type a: XTuple - :return: True if the tuple contains ScaledBasis objects, False otherwise - :rtype: bool - """ - if type(a) is tuple: - return any([has_scaled_basis(x) for x in a]) - else: - return isinstance(a, ScaledBasis) - - -def _tuple_str(t: tuple) -> str: - """ - Constructs a string representation of a python tuple without calling __repr__ on its elements. - """ - - def construct_inner_str(t) -> str: - if not isinstance(t, tuple): - return pretty_str(t) - res = "" - l = len(t) - for i in range(l): - res += pretty_str(t[i]) - if i < l - 1: - res += "," - return res - - res = "(" + construct_inner_str(t) + ")" - return res - - -def pretty_str(arg) -> str: - """ - Constructs a concise readable pretty string. - """ - if isinstance(arg, tuple): - # _tuple_str for tuples - return _tuple_str(arg) - elif arg is None: - # We interpret None as underscores for slicers - return "_" - else: - # Fallback to __str__ - return arg.__str__() - - -@dsl_user_op -def printf(*args, loc=None, ip=None) -> None: - """ - Print a value or a list of values. - - It supports c-style printf format as well: - - .. code-block:: python - - a = cute.make_layout(shape=(10, 10), stride=(10, 1)) - b = cutlass.Float32(1.234) - cute.printf(a, b) - cute.printf("a={}, b={}", a, b) - cute.printf("a={}, b=%.2f", a, b) - - :param args: List of values to print - :type args: list - :param loc: Source location where it's called, defaults to None - :type loc: source location, optional - :param ip: Insertion pointer, defaults to None - :type ip: insertion pointer, optional - :raises ValueError: If no arguments are provided or if an unsupported argument type is passed - """ - - if len(args) == 0: - raise ValueError("expects at least one argument to print") - - if isinstance(args[0], str): - fmt = args[0] + "\n" - args = args[1:] - else: - fmt = "{}" + ", {}" * (len(args) - 1) + "\n" - - def process_arg(arg): - arg0 = arg.value if isinstance(arg, Numeric) else arg - - if isinstance(arg0, ir.Value): - return arg0 - elif isinstance(arg0, bool): - return const(arg0, Boolean) - elif isinstance(arg0, int): - return const(arg0, Int32) - elif isinstance(arg0, float): - return const(arg0, Float32) - elif has_underscore(arg0): - # Assume it's a coordinate - return _pack_coord(arg0) - elif has_scaled_basis(arg0): - # Assume it's a stride - return _pack_stride(arg0) - elif isinstance(arg0, tuple): - # Assume it's an int_tuple - return _pack_int_tuple(arg0) - elif isinstance(arg0, (_Tensor, _Pointer)): - return arg0.value - else: - raise TypeError(f"unsupported argument type in printf, got {type(arg)}") - - args = [process_arg(a) for a in args] - _cute_ir.print_(args, fmt=fmt, loc=loc, ip=ip) - - -@dsl_user_op -def front(input, *, loc=None, ip=None): - """Recursively get the first element of input. - - This function traverses a hierarchical structure (like a layout or tensor) - and returns the first element at the deepest level. It's particularly useful - for accessing the first stride value in a layout to determine properties like - majorness. - - :param input: The hierarchical structure to traverse - :type input: Union[Tensor, Layout, Stride] - :param loc: Source location where it's called, defaults to None - :type loc: source location, optional - :param ip: Insertion pointer for IR generation, defaults to None - :type ip: insertion pointer, optional - :return: The first element at the deepest level of the input structure - :rtype: Union[int, float, bool, ir.Value] - """ - if rank(input) == 1 and depth(input) == 0: - return input - else: - return front(get(input, mode=[0], loc=loc, ip=ip), loc=loc, ip=ip) - - -@dsl_user_op -def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: - """ - Check whether a mode in stride is the major mode. - """ - first_stride = front(get(stride, mode=[mode], loc=loc, ip=ip), loc=loc, ip=ip) - if is_dynamic_expression(first_stride): - return False - return True if first_stride == 1 else False - - -def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: - """ - Find the leading dimension of a shape and stride. - - :param shape: The shape of the tensor or layout - :type shape: Shape - :param stride: The stride of the tensor or layout - :type stride: Stride - :return: The leading dimension index or indices - :rtype: Union[int, Tuple[int, ...], None] - - The return value depends on the stride pattern: - - * If a single leading dimension is found, returns an integer index - * If nested leading dimensions are found, returns a tuple of indices - * If no leading dimension is found, returns None - """ - - def pred_fn(val, pos): - # skip dynamic values which can't be compared - # find the candidate target val, stride at this position is 1 - if (not is_dynamic_expression(val)) and (val == 1): - # extract the shape at this position - mode = [pos] if isinstance(pos, int) else list(pos) - s = get(shape, mode) - if is_dynamic_expression(s) or s != 1: - # shape at this position is dynamic value or not 1 - # we found the leading dimension - return True - return False - - return find_if(stride, pred_fn=pred_fn) - - -@dsl_user_op -def find_if( - t: Union[tuple, ir.Value, int], - pred_fn: Callable[[int, Tuple[int, ...]], bool], - *, - loc=None, - ip=None, -) -> Union[int, Tuple[int, ...], None]: - """Find the first position in t where pred_fn(val, pos) returns True. - - :param t: The search space - :type t: Union[tuple, ir.Value, int] - :param pred_fn: A callable object (lambda, function, etc.) that predicates the value and position in t. - It takes the current leaf value and position, returns True if the value or position is satisfied. - :type pred_fn: Callable[[int, Tuple[int, ...]], bool] - :return: Index if found at top level, tuple of indices showing nested position, or None if not found - :rtype: Union[int, Tuple[int, ...], None] - - **Examples:** - - .. code-block:: python - - # Find the first position of x in t - t = (3, 4) - find_if(t, pred_fn=lambda val, pos: val == x) - - .. code-block:: python - - # find the leading dimension - shape = (3, 4) - stride = (4, 1) - # Find value 1 in stride where the corresponding shape is not 1 - def pred_fn(val, pos): - mode = [pos] if isinstance(pos, int) else list(pos) - return val == 1 and get(shape, mode) != 1 - find_if(stride, pred_fn=pred_fn) - """ - - def _find_if_impl(curr, pos, *, loc=None, ip=None): - if isinstance(curr, tuple): - # Recursively search nested tuple - for i in range(rank(curr)): - sub_curr = get(curr, mode=[i], loc=loc, ip=ip) - sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) - res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) - if res_pos is not None: - return res_pos - else: - # For leaf values, check if it matches x - if pred_fn(curr, pos): - return pos - return None - - def _check_pred_fn(): - if not callable(pred_fn): - raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") - signature = inspect.signature(pred_fn) - if len(signature.parameters) != 2: - raise ValueError( - f"pred_fn must have two parameters (value, pos), but got {len(signature.parameters)}" - ) - - _check_pred_fn() - - for i in range(rank(t)): - curr = get(t, mode=[i], loc=loc, ip=ip) - res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) - if res_pos is not None: - return res_pos - return None - - -@dsl_user_op -def find( - t: Union[tuple, ir.Value, int], - x: int, - *, - loc=None, - ip=None, -) -> Union[int, Tuple[int, ...], None]: - """Find the first position of a value ``x`` in a hierarchical structure ``t``. - - Searches for the first occurrence of x in t, optionally excluding positions - where a comparison value matches. The search can traverse nested structures - and returns either a single index or a tuple of indices for nested positions. - - :param t: The search space - :type t: Union[tuple, ir.Value, int] - :param x: The static integer x to search for - :type x: int - :return: Index if found at top level, tuple of indices showing nested position, or None if not found - :rtype: Union[int, Tuple[int, ...], None] - """ - if not isinstance(x, int): - raise TypeError(f"find() requires a static x to search for, but got {x}") - - def pred_fn(val, pos): - # Skip dynamic values which can't be compared - return not is_dynamic_expression(val) and val == x - - return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) - - -def transform_leaf(f, *args): - """ - Apply a function to the leaf nodes of nested tuple structures. - - This function traverses nested tuple structures in parallel and applies the function f - to corresponding leaf nodes. All input tuples must have the same nested structure. - - :param f: Function to apply to leaf nodes - :type f: Callable - :param args: One or more nested tuple structures with matching profiles - :return: A new nested tuple with the same structure as the inputs, but with leaf values transformed by f - :raises TypeError: If the input tuples have different nested structures - - Example: - - .. code-block:: python - - >>> transform_leaf(lambda x: x + 1, (1, 2)) - (2, 3) - >>> transform_leaf(lambda x, y: x + y, (1, 2), (3, 4)) - (4, 6) - >>> transform_leaf(lambda x: x * 2, ((1, 2), (3, 4))) - ((2, 4), (6, 8)) - """ - if all(isinstance(t, tuple) for t in args): - return tuple(transform_leaf(f, *_args) for _args in zip(*args)) - elif all(not isinstance(t, tuple) for t in args): - return f(*args) - else: - raise TypeError(f"profile of input tuples doesn't match: {args}") - - -@dsl_user_op -def assume(src, divby=None, *, loc=None, ip=None): - if divby is None: - return src - - if isinstance(src, Integer): - width = type(src).width - src_val = src.ir_value() - else: - width = src.type.width - src_val = src - - res_ty = _cute_ir.ConstrainedIntType.get(divby, width) - assumed_val = _cute_ir.assume(res_ty, src_val, loc=loc, ip=ip) - return type(src)(IntValue(_pack_int_tuple(assumed_val, loc=loc, ip=ip))) - - -@dsl_user_op -def make_swizzle(b, m, s, *, loc=None, ip=None): - # canonicalize to <0, 4, 3> for identity swizzle (as compiler assumes <0, 4, 3>) - if b == 0: - m, s = 4, 3 - ty = ir.Type.parse(f'!cute.swizzle<"S<{b},{m},{s}>">') - return Swizzle(_cute_ir.static(ty, loc=loc, ip=ip)) - - -# -# Tuple API (also used by layouts and tensors) -# - - -def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: - """Returns the depth (nesting level) of a tuple, layout, or tensor. - - The depth of a tuple is the maximum depth of its elements plus 1. - For an empty tuple, the depth is 1. For layouts and tensors, the depth - is determined by the depth of their shape. For non-tuple values (e.g., integers), - the depth is considered 0. - - :param a: The object whose depth is to be determined - :type a: Union[XTuple, Layout, ComposedLayout, Tensor, Any] - :return: The depth of the input object - :rtype: int - - Example: - - .. code-block:: python - - >>> depth(1) - 0 - >>> depth((1, 2)) - 1 - >>> depth(((1, 2), (3, 4))) - 2 - """ - if type(a) is tuple: - if not a: - return 1 - return max(depth(x) for x in a) + 1 - elif isinstance(a, (Layout, ComposedLayout, Tensor)): - return depth(a.shape) - else: - return 0 - - -@lru_cache_ir() -def rank(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: - """Returns the rank (dimensionality) of a tuple, layout, or tensor. - - The rank of a tuple is its length. For layouts and tensors, the rank is - determined by the rank of their shape. For non-tuple values (e.g., integers), - the rank is considered 1 for convenience. - - :param a: The object whose rank is to be determined - :type a: Union[XTuple, Layout, ComposedLayout, Tensor, Any] - :return: The rank of the input object - :rtype: int - - This function is used in layout algebra to determine the dimensionality - of tensors and layouts for operations like slicing and evaluation. - """ - if isinstance(a, tuple): - return len(a) - elif isinstance(a, (Layout, ComposedLayout, Tensor)): - return rank(a.shape) - elif depth(a) == 0: - return 1 - else: - raise TypeError(f"unsupported type in rank, got {type(a)}") - - -def is_congruent( - a: Union[XTuple, Layout, ComposedLayout, Tensor], - b: Union[XTuple, Layout, ComposedLayout, Tensor], -) -> bool: - """ - Returns whether a is congruent to b. - - Congruence is an equivalence relation between hierarchical structures. - - Two objects are congruent if: - * They have the same rank, AND - * They are both non-tuple values, OR - * They are both tuples AND all corresponding elements are congruent. - - Congruence requires type matching at each level -- scalar values match with - scalar values, and tuples match with tuples of the same rank. - - :param a: First object to compare - :type a: Union[XTuple, Layout, ComposedLayout, Tensor] - :param b: Second object to compare - :type b: Union[XTuple, Layout, ComposedLayout, Tensor] - :return: True if a and b are congruent, False otherwise - :rtype: bool - """ - if isinstance(a, (Layout, ComposedLayout, Tensor)): - a = a.shape - if isinstance(b, (Layout, ComposedLayout, Tensor)): - b = b.shape - if isinstance(a, tuple) and isinstance(b, tuple): - return (len(a) == len(b)) and all(is_congruent(x, y) for x, y in zip(a, b)) - if isinstance(a, tuple) or isinstance(b, tuple): - return False - return True - - -def is_weakly_congruent( - a: Union[XTuple, Layout, ComposedLayout, Tensor], - b: Union[XTuple, Layout, ComposedLayout, Tensor], -) -> bool: - """ - Returns whether a is weakly congruent to b. - - Weak congruence is a partial order on hierarchical structures. - - Object X is weakly congruent to object Y if: - * X is a non-tuple value, OR - * X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent. - - Weak congruence allows scalar values to match with tuples, making it useful - for determining whether an object has a hierarchical structure "up to" another. - - :param a: First object to compare - :type a: Union[XTuple, Layout, ComposedLayout, Tensor] - :param b: Second object to compare - :type b: Union[XTuple, Layout, ComposedLayout, Tensor] - :return: True if a and b are weakly congruent, False otherwise - :rtype: bool - """ - if isinstance(a, (Layout, ComposedLayout, Tensor)): - a = a.shape - if isinstance(b, (Layout, ComposedLayout, Tensor)): - b = b.shape - if not isinstance(a, tuple): - return True - if isinstance(a, tuple) and isinstance(b, tuple): - return (len(a) == len(b)) and all( - is_weakly_congruent(x, y) for x, y in zip(a, b) - ) - if isinstance(a, tuple) or isinstance(b, tuple): - return False - return True - - -@overload -def get(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... -@overload -def get(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... -@overload -def get(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... -@overload -def get(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... -@overload -def get(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... -@overload -def get(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... -@overload -def get(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... - - -@dsl_user_op -def get(input, mode: List[int], *, loc=None, ip=None): - """Extract a specific element or sub-layout from a layout or tuple. - - This function recursively traverses the input according to the mode indices, - extracting the element at the specified path. For layouts, this operation - corresponds to extracting a specific sub-layout. - - :param input: The input layout or tuple to extract from - :type input: Layout, ComposedLayout, tuple - :param mode: Indices specifying the path to traverse for extraction - :type mode: List[int] - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: The extracted element or sub-layout - :rtype: Layout, ComposedLayout, or element type - :raises ValueError: If any index in mode is out of range - :raises TypeError: If mode contains non-integer elements or if input has unsupported type - - :postcondition: ``get(t, mode=find(x,t)) == x if find(x,t) != None else True`` - - **Examples:** - - .. code-block:: python - - layout = make_layout(((4, 8), (16, 1), 8), stride=((1, 4), (32, 0), 512)) - sub_layout = get(layout, mode=[0, 1]) # 8:4 - sub_layout = get(layout, mode=[1]) # (16, 1):(32, 0) - """ - # Empty mode returns input and terminates the recursive call - if not mode: - return input - - if rank(input) <= mode[0]: - raise ValueError( - f"elements in mode must be less than rank({input}), got {mode}" - ) - - if depth(input) == 0: - return input - elif isinstance(input, tuple): - if not isinstance(mode[0], int): - raise TypeError( - f"invalid element in mode, expects int, got {type(mode[0])}" - ) - return get(input[mode[0]], mode=mode[1:]) - else: - if not isinstance(input, (Layout, ComposedLayout)): - raise TypeError(f"unsupported type of input, got {type(input)}") - return _cute_ir.get( - input.type.get_op_res_type(mode=mode), input, mode=mode, loc=loc, ip=ip - ) - - -@overload -def select(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... -@overload -def select(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... -@overload -def select(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... -@overload -def select(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... -@overload -def select(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... -@overload -def select(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... -@overload -def select(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... - - -@dsl_user_op -def select(input, mode: List[int], *, loc=None, ip=None): - """Select modes from input. - - :param input: Input to select from - :type input: Layout, ComposedLayout, tuple - :param mode: Indices specifying which dimensions or elements to select - :type mode: List[int] - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: A new instance with selected dimensions/elements - :rtype: Layout, ComposedLayout, tuple - :raises ValueError: If any index in mode is out of range - :raises TypeError: If the input type is invalid - - **Examples:** - - .. code-block:: python - - # Select specific dimensions from a layout - layout = make_layout((4, 8, 16), stride=(32, 4, 1)) - selected = select(layout, mode=[0, 2]) # Select mode 0 and mode 2 - # Result: (4, 16):(32, 1) - - # Select elements from a tuple - t = (1, 2, 3, 4, 5) - selected = select(t, mode=[0, 2, 4]) # Select mode 0, mode 2, and mode 4 - # Result: (1, 3, 5) - """ - if any((not isinstance(i, int)) or (i >= rank(input)) for i in mode): - raise ValueError( - f"invalid mode element for input of rank {rank(input)}, got {mode=}" - ) - - if isinstance(input, tuple): - return tuple(input[i] for i in mode) - - if not isinstance(input, (Layout, ComposedLayout)): - raise TypeError(f"unsupported type of input, got {type(input)}") - - return _cute_ir.select(input, mode=mode, loc=loc, ip=ip) - - -@overload -def group_modes(input: Shape, begin: int, end: int, *, loc=None, ip=None) -> Shape: ... -@overload -def group_modes( - input: Stride, begin: int, end: int, *, loc=None, ip=None -) -> Stride: ... -@overload -def group_modes(input: Coord, begin: int, end: int, *, loc=None, ip=None) -> Coord: ... -@overload -def group_modes( - input: IntTuple, begin: int, end: int, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def group_modes(input: Tile, begin: int, end: int, *, loc=None, ip=None) -> Tile: ... -@overload -def group_modes( - input: Layout, begin: int, end: int, *, loc=None, ip=None -) -> Layout: ... -@overload -def group_modes( - input: ComposedLayout, begin: int, end: int, *, loc=None, ip=None -) -> ComposedLayout: ... -@overload -def group_modes( - input: Tensor, begin: int, end: int, *, loc=None, ip=None -) -> Tensor: ... - - -@dsl_user_op -def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): - """Group modes of a hierarchical tuple or layout into a single mode. - - This function groups a range of modes from the input object into a single mode, - creating a hierarchical structure. For tuples, it creates a nested tuple containing - the specified range of elements. For layouts and other CuTe objects, it creates - a hierarchical representation where the specified modes are grouped together. - - :param input: Input object to group modes from (layout, tuple, etc.) - :type input: Layout, ComposedLayout, tuple, Shape, Stride, etc. - :param beg: Beginning index of the range to group (inclusive) - :type beg: int - :param end: Ending index of the range to group (exclusive) - :type end: int - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: A new object with the specified modes grouped - :rtype: Same type as input with modified structure - - **Examples:** - - .. code-block:: python - - # Group modes in a tuple - t = (2, 3, 4, 5) - grouped = group_modes(t, 1, 3) # (2, (3, 4), 5) - - # Group modes in a layout - layout = make_layout((2, 3, 4, 5)) - grouped_layout = group_modes(layout, 1, 3) # Layout with shape (2, (3, 4), 5) - - # Group modes in a shape - shape = make_shape(2, 3, 4, 5) - grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5) - """ - if depth(input) == 0 and is_integer(input): - return (input,) - if isinstance(input, tuple): - return (*input[:begin], (input[begin:end]), *input[end:]) - return _cute_ir.group_modes( - input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip - ) - - -@overload -def slice_(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... -@overload -def slice_(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... -@overload -def slice_(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... -@overload -def slice_(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... -@overload -def slice_(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... -@overload -def slice_(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... -@overload -def slice_( - src: ComposedLayout, coord: Coord, *, loc=None, ip=None -) -> ComposedLayout: ... -@overload -def slice_(src: Tensor, coord: Coord, *, loc=None, ip=None) -> Tensor: ... - - -@dsl_user_op -def slice_(src, coord: Coord, *, loc=None, ip=None): - """Perform a slice operation on a source object using the given coordinate. - - This function implements CuTe's slicing operation which extracts a subset of elements - from a source object (tensor, layout, etc.) based on a coordinate pattern. The slice - operation preserves the structure of the source while selecting specific elements. - - :param src: Source object to be sliced (tensor, layout, tuple, etc.) - :type src: Union[Tensor, Layout, IntTuple, Value] - :param coord: Coordinate pattern specifying which elements to select - :type coord: Coord - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A new object containing the sliced elements - :rtype: Union[Tensor, Layout, IntTuple, tuple] - :raises ValueError: If the coordinate pattern is incompatible with source - - **Examples:** - - .. code-block:: python - - # Layout slicing - layout = make_layout((4,4)) - - # Select 1st index of first mode and keep all elements in second mode - sub_layout = slice_(layout, (1, None)) - - .. code-block:: python - - # Basic tensor slicing - tensor = make_tensor(...) # Create a 2D tensor - - # Select 1st index of first mode and keep all elements in second mode - sliced = slice_(tensor, (1, None)) - - .. code-block:: python - - # Select 2nd index of second mode and keep all elements in first mode - sliced = slice_(tensor, (None, 2)) - - Note: - - `None` represents keeping all elements in that mode - - Slicing preserves the layout/structure of the original object - - Can be used for: - * Extracting sub-tensors/sub-layouts - * Creating views into data - * Selecting specific patterns of elements - """ - - def lift_slice(a, b): - if isinstance(a, tuple): - if (not isinstance(b, tuple)) or (len(a) != len(b)): - raise ValueError("coord must be weakly congruent to src in slice_") - return reduce( - lambda p, q: p + q, (lift_slice(x, y) for x, y in zip(a, b)), () - ) - elif a is None: - return (b,) - else: - return () - - if is_integer(src) or isinstance(src, tuple): - if isinstance(coord, tuple): - if (not isinstance(src, tuple)) or (len(coord) != len(src)): - raise ValueError("coord must be weakly congruent to src in slice_") - return reduce( - lambda p, q: p + q, (lift_slice(x, y) for x, y in zip(coord, src)), () - ) - elif coord is None: - return src - else: - return () - - res_type = None - if isinstance(src, Tensor): - res_type = src.element_type - src = src.value - coord_val = _pack_coord(coord, loc=loc, ip=ip) - res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res - - -@overload -def dice(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... -@overload -def dice(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... -@overload -def dice(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... -@overload -def dice(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... -@overload -def dice(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... -@overload -def dice(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... -@overload -def dice(src: ComposedLayout, coord: Coord, *, loc=None, ip=None) -> ComposedLayout: ... - - -@dsl_user_op -@lru_cache_ir() -def dice(src, dicer, *, loc=None, ip=None): - """Keep modes in input when it is paired with an integer in dicer. - - This function performs dicing operation on the input based on the dicer coordinate. - Dicing is a fundamental operation in CuTe that allows selecting specific modes from - a tensor or layout based on a coordinate pattern. - - :param dicer: A static coordinate indicating how to dice the input - :type dicer: Coord - :param input: The operand to be diced on - :type input: Union[IntTuple, Shape, Stride, Coord, Layout, ComposedLayout] - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: The diced result with selected modes from the input - :rtype: Union[IntTuple, Shape, Stride, Coord, Layout, ComposedLayout] - :raises TypeError: If dicer has an unsupported type - :raises ValueError: If input is not provided - - **Examples:** - - .. code-block:: python - - # Basic dicing of a layout - layout = make_layout((32,16,8)) - - # Keep only first and last modes - diced = dice((1,None,1), layout) - - Note: - - The dicer coordinate must be static - - Use underscore (_) to remove a mode - """ - if not is_static(dicer): - raise ValueError(f"expects dicer to be static, but got {dicer}") - - def lift_dice(a, b): - if isinstance(a, tuple): - if (not isinstance(b, tuple)) or (len(a) != len(b)): - raise ValueError("dicer must be weakly congruent to input in dice") - return reduce( - lambda p, q: p + q, (lift_dice(x, y) for x, y in zip(a, b)), () - ) - elif a is None: - return () - else: - return (b,) - - if is_integer(src) or isinstance(src, tuple): - if isinstance(dicer, tuple): - if (not isinstance(src, tuple)) or (len(dicer) != len(src)): - raise ValueError("dicer must be weakly congruent to src in dice") - return reduce( - lambda p, q: p + q, (lift_dice(x, y) for x, y in zip(dicer, src)), () - ) - elif dicer is None: - return () - else: - return src - - dicer_val = _pack_coord(dicer, loc=loc, ip=ip) - return _cute_ir.dice(src, dicer_val.type.attribute, loc=loc, ip=ip) - - -def wrap(x) -> tuple: - """ - Wraps the input into a tuple if not a tuple. - """ - if isinstance(x, tuple): - return x - return (x,) - - -def _extend(func, input, elem, up_to_rank, loc, ip): - if input is None: - raise ValueError(f"No input provided for input") - - if isinstance(input, (Layout, ComposedLayout)): - if elem is None: - elem = make_layout(1) - elif not isinstance(elem, Layout): - raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") - N = rank(input) + 1 if up_to_rank is None else up_to_rank - return func(N, input, elem, loc=loc, ip=ip) - - if is_valid_leaf(input) or isinstance(input, tuple): - if elem is None: - elem = 1 - if (not isinstance(elem, tuple)) and (not is_valid_leaf(elem)): - raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") - - input = wrap(input) - repeat_cnt = 1 if up_to_rank is None else up_to_rank - rank(input) - if repeat_cnt == 0: - return input - elif repeat_cnt < 0: - raise ValueError(f"up_to_rank must be >= rank(input)") - else: - if func is _cute_ir.prepend_to_rank: - return (elem,) * repeat_cnt + input - else: - return input + (elem,) * repeat_cnt - - raise TypeError(f"invalid type for input, got {type(input)}") - - -@overload -def prepend( - input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None -) -> Shape: ... -@overload -def prepend( - input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None -) -> Stride: ... -@overload -def prepend( - input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None -) -> Coord: ... -@overload -def prepend( - input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def prepend(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... -@overload -def prepend( - input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None -) -> Layout: ... -@overload -def prepend( - input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): - """Extend input to rank up_to_rank by prepending elem in front of input. - - This function extends the input object by prepending elements to reach a desired rank. - It supports various CuTe types including shapes, layouts, tensors etc. - - :param input: Source to be prepended to - :type input: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] - :param elem: Element to prepend to input - :type elem: Union[Shape, Stride, Coord, IntTuple, Tile, Layout] - :param up_to_rank: The target rank after extension, defaults to None - :type up_to_rank: Union[None, int], optional - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint] - :return: The extended result with prepended elements - :rtype: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] - :raises ValueError: If up_to_rank is less than input's current rank - :raises TypeError: If input or elem has unsupported type - - **Examples:** - - .. code-block:: python - - # Prepend to a Shape - shape = (4,4) - prepend(shape, 2) # Returns (2,4,4) - - # Prepend to a Layout - layout = make_layout((8,8)) - prepend(layout, make_layout((2,))) # Returns (2,8,8):(1,1,8) - - # Prepend with target rank - coord = (1,1) - prepend(coord, 0, up_to_rank=4) # Returns (0,0,1,1) - """ - return _extend(_cute_ir.prepend_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) - - -@overload -def append( - input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None -) -> Shape: ... -@overload -def append( - input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None -) -> Stride: ... -@overload -def append( - input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None -) -> Coord: ... -@overload -def append( - input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def append(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... -@overload -def append( - input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None -) -> Layout: ... -@overload -def append( - input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): - """Extend input to rank up_to_rank by appending elem to the end of input. - - This function extends the input object by appending elements to reach a desired rank. - It supports various CuTe types including shapes, layouts, tensors etc. - - :param input: Source to be appended to - :type input: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] - :param elem: Element to append to input - :type elem: Union[Shape, Stride, Coord, IntTuple, Tile, Layout] - :param up_to_rank: The target rank after extension, defaults to None - :type up_to_rank: Union[None, int], optional - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint] - :return: The extended result with appended elements - :rtype: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] - :raises ValueError: If up_to_rank is less than input's current rank - :raises TypeError: If input or elem has unsupported type - - **Examples:** - - .. code-block:: python - - # Append to a Shape - shape = (4,4) - append(shape, 2) # Returns (4,4,2) - - # Append to a Layout - layout = make_layout((8,8)) - append(layout, make_layout((2,))) # Returns (8,8,2):(1,8,1) - - # Append with target rank - coord = (1,1) - append(coord, 0, up_to_rank=4) # Returns (1,1,0,0) - - Note: - - The function preserves the structure of the input while extending it - - Can be used to extend tensors, layouts, shapes and other CuTe types - - When up_to_rank is specified, fills remaining positions with elem - - Useful for tensor reshaping and layout transformations - """ - return _extend(_cute_ir.append_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) - - -@dsl_user_op -def prepend_ones( - t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None -) -> Tensor: - return make_tensor( - t.iterator, prepend(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip - ) - - -@dsl_user_op -def append_ones( - t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None -) -> Tensor: - return make_tensor( - t.iterator, append(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip - ) - - -def repeat_like(x, target): - """Creates an object congruent to target and filled with x. - - This function recursively creates a nested tuple structure that matches the structure - of the target, with each leaf node filled with the value x. - - :param x: The value to fill the resulting structure with - :type x: Any - :param target: The structure to mimic - :type target: Union[tuple, Any] - :return: A structure matching target but filled with x - :rtype: Union[tuple, Any] - - **Examples:** - - .. code-block:: python - - repeat_like(0, (1, 2, 3)) # Returns (0, 0, 0) - repeat_like(1, ((1, 2), 3)) # Returns ((1, 1), 1) - repeat_like(2, 5) # Returns 2 - """ - if not isinstance(target, tuple): - return x - if not target: - return () - if len(target) == 1: - return (repeat_like(x, target[0]),) - return tuple(repeat_like(x, t) for t in target) - - -def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple: - """Flattens a potentially nested tuple structure into a flat tuple. - - This function recursively traverses the input structure and flattens it into - a single-level tuple, preserving the order of elements. - - :param a: The structure to flatten - :type a: Union[IntTuple, Coord, Shape, Stride] - :return: A flattened tuple containing all elements from the input - :rtype: tuple - - **Examples:** - - .. code-block:: python - - flatten_to_tuple((1, 2, 3)) # Returns (1, 2, 3) - flatten_to_tuple(((1, 2), 3)) # Returns (1, 2, 3) - flatten_to_tuple((1, (2, (3,)))) # Returns (1, 2, 3) - """ - if not isinstance(a, tuple): - return wrap(a) - else: - return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a))) - - -@overload -def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ... -@overload -def flatten(a: Tensor) -> Tensor: ... -@overload -def flatten(a: Layout) -> Layout: ... - - -def flatten(a): - """Flattens a CuTe data structure into a simpler form. - - For tuples, this function flattens the structure into a single-level tuple. - For layouts, it returns a new layout with flattened shape and stride. - For tensors, it returns a new tensor with flattened layout. - For other types, it returns the input unchanged. - - :param a: The structure to flatten - :type a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor] - :return: The flattened structure - :rtype: Union[tuple, Any] - - **Examples:** - - .. code-block:: python - - flatten((1, 2, 3)) # Returns (1, 2, 3) - flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4) - flatten(5) # Returns 5 - flatten(Layout(shape, stride)) # Returns Layout(flatten(shape), flatten(stride)) - flatten(Tensor(layout)) # Returns Tensor(flatten(layout)) - - """ - if isinstance(a, Tensor): - return make_tensor(a.iterator, flatten(a.layout)) - elif isinstance(a, Layout): - return make_layout(flatten(a.shape), stride=flatten(a.stride)) - elif isinstance(a, tuple): - return flatten_to_tuple(a) - else: - return a - - -def unflatten( - sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]], profile: XTuple -) -> XTuple: - """Unflatten a flat tuple into a nested tuple structure according to a profile. - - This function transforms a flat sequence of elements into a nested tuple structure - that matches the structure defined by the profile parameter. It traverses the profile - structure and populates it with elements from the sequence. - - sequence must be long enough to fill the profile. Raises RuntimeError if it is not. - - :param sequence: A flat sequence of elements to be restructured - :type sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]] - :param profile: A nested tuple structure that defines the shape of the output - :type profile: XTuple - :return: A nested tuple with the same structure as profile but containing elements from sequence - :rtype: XTuple - - Example: - >>> unflatten([1, 2, 3, 4], ((0, 0), (0, 0))) - ((1, 2), (3, 4)) - """ - - def _make_generator(): - for element in sequence: - yield element - - xs = _make_generator() - return transform_leaf(lambda _: next(xs), profile) - - -@dsl_user_op -def elem_less( - lhs: Union[Shape, IntTuple, Coord], - rhs: Union[Shape, IntTuple, Coord], - *, - loc=None, - ip=None, -): - lhs_val = _pack_coord(lhs, loc=loc, ip=ip) - rhs_val = _pack_coord(rhs, loc=loc, ip=ip) - return Boolean(_cute_ir.elem_less(lhs_val, rhs_val, loc=loc, ip=ip)) - - -@overload -def filter_zeros( - input: Layout, *, target_profile=None, loc=None, ip=None -) -> Layout: ... -@overload -def filter_zeros( - input: Tensor, *, target_profile=None, loc=None, ip=None -) -> Tensor: ... - - -@dsl_user_op -def filter_zeros(input, *, target_profile=None, loc=None, ip=None): - """Filter out zeros from a layout or tensor. - - This function removes zero-stride dimensions from a layout or tensor. - Refer to https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md - for more layout algebra operations. - - :param input: The input layout or tensor to filter - :type input: Layout or Tensor - :param target_profile: Target profile for the filtered result, defaults to None - :type target_profile: optional - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: The filtered layout or tensor with zeros removed - :rtype: Layout or Tensor - :raises TypeError: If input is not a Layout or Tensor - """ - if not isinstance(input, (Layout, Tensor)): - raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") - if isinstance(input, Tensor): - input = input.value - return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip) - - -@dsl_user_op -def filter(input: Union[Layout, Tensor], *, loc=None, ip=None): - """Filter a layout or tensor. - - This function filters a layout or tensor according to CuTe's filtering rules. - - :param input: The input layout or tensor to filter - :type input: Layout or Tensor - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: The filtered layout or tensor - :rtype: Layout or Tensor - :raises TypeError: If input is not a Layout or Tensor - """ - if not isinstance(input, (Layout, Tensor)): - raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") - if isinstance(input, _Tensor): - input = input.value - return _cute_ir.filter(input, loc=loc, ip=ip) - - -@dsl_user_op -def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): - """Return product of the given IntTuple or Shape. - - Computes the product of all elements in the input tuple or shape. - Returns static value if type is static. - - :param a: The input tuple or shape - :type a: IntTuple or Shape - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: Static product of IntTuple or Shape if static, otherwise a Value - :rtype: int or Value - :raises TypeError: If input is not an IntTuple or Shape - """ - if is_integer(a): - return a - if isinstance(a, tuple): - a_val = _pack_int_tuple(a, loc=loc, ip=ip) - res = _cute_ir.tuple_product(a_val, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - else: - raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") - - -@overload -def product_like( - a: IntTuple, target_profile: XTuple, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def product_like(a: Shape, target_profile: XTuple, *, loc=None, ip=None) -> Shape: ... - - -@dsl_user_op -def product_like( - a: Union[IntTuple, Shape], target_profile: XTuple, *, loc=None, ip=None -): - """Return product of the given IntTuple or Shape at leaves of `target_profile`. - - This function computes products according to the structure defined by target_profile. - - :param a: The input tuple or shape - :type a: IntTuple or Shape - :param target_profile: The profile that guides how products are computed - :type target_profile: XTuple - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: The resulting tuple with products computed according to target_profile - :rtype: IntTuple or Shape - :raises TypeError: If inputs have incompatible types - :raises ValueError: If inputs have incompatible shapes - """ - # Perform product at leaf of `target_profile` - if not isinstance(target_profile, tuple): - return product(a, loc=loc, ip=ip) - else: - if not isinstance(a, tuple): - raise TypeError(f"expects `a` tuple but got {a}") - - if len(a) != len(target_profile): - raise ValueError(f"expects `a` and `guide` have the same rank") - - return tuple( - product_like(x, g, loc=loc, ip=ip) for x, g in zip(a, target_profile) - ) - - -@overload -def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: ... -@overload -def product_each(a: Shape, *, loc=None, ip=None) -> Shape: ... - - -@dsl_user_op -def product_each(a, *, loc=None, ip=None): - """Compute products for each component of the input. - - Returns a rank(a) tuple `result` such that get(result, mode=[i]) == product(get(a, mode=[i])) - - :param a: The input tuple or shape - :type a: IntTuple or Shape - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: A tuple containing products for each component - :rtype: tuple - :raises TypeError: If input is not an IntTuple or Shape - """ - if is_integer(a): - return a - if isinstance(a, tuple): - if not a: - return 1 - else: - a_val = _pack_int_tuple(a, loc=loc, ip=ip) - res = _cute_ir.tuple_product_each(a_val, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - else: - raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") - - -@dsl_user_op -def size( - a: Union[IntTuple, Shape, Layout, ComposedLayout, Tensor], - mode: List[int] = [], - *, - loc=None, - ip=None, -) -> Int: - """Return size of domain of layout or tensor. - - Computes the size (number of elements) in the domain of a layout or tensor. - For layouts, this corresponds to the shape of the coordinate space. - See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md - for more details on layout domains. - - :param a: The input object whose size to compute - :type a: IntTuple, Shape, Layout, ComposedLayout or Tensor - :param mode: List of mode(s) for size calculation. If empty, computes total size, defaults to [] - :type mode: list of int, optional - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: Static size of layout or tensor if static, otherwise a Value - :rtype: int or Value - :raises ValueError: If mode contains non-integer elements - """ - if any(not isinstance(m, int) for m in mode): - raise ValueError(f"expects integer elements in mode, but got {mode}") - - if isinstance(a, (TiledMma, TiledCopy)): - return a.size - a_val = None - if not isinstance(a, (Layout, ComposedLayout, Tensor)): - a_val = _pack_int_tuple(a, loc=loc, ip=ip) - elif isinstance(a, Tensor): - a_val = a.value - else: - a_val = a - - res = _cute_ir.size(a_val, mode=mode, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore - - -@dsl_user_op -def shape_div(lhs: Shape, rhs: Shape, *, loc=None, ip=None) -> Shape: - """Perform element-wise division of shapes. - - This function performs element-wise division between two shapes. - - :param lhs: Left-hand side shape - :type lhs: Shape - :param rhs: Right-hand side shape - :type rhs: Shape - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: The result of element-wise division - :rtype: Shape - """ - lhs = _pack_shape(lhs, loc=loc, ip=ip) - rhs = _pack_shape(rhs, loc=loc, ip=ip) - res = _cute_ir.shape_div(lhs, rhs, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - - -@dsl_user_op -def ceil_div(input: Shape, tiler: Tiler, *, loc=None, ip=None) -> Shape: - """ - Compute the ceiling division of a target shape by a tiling specification. - - This function computes the number of tiles required to cover the target domain. - It is equivalent to the second mode of `zipped_divide(input, tiler)`. - - :param input: A tuple of integers representing the dimensions of the target domain. - :type input: Shape - :param tiler: The tiling specification. - :type tiler: Union[Layout, Shape, Tile] - :param loc: Optional location information for IR diagnostics. - :type loc: optional - :param ip: Optional instruction pointer or context for underlying IR functions. - :type ip: optional - :return: A tuple of integers representing the number of tiles required along each dimension, - i.e. the result of the ceiling division of the input dimensions by the tiler dimensions. - :rtype: Shape - - Example: - - .. code-block:: python - - import cutlass.cute as cute - @cute.jit - def foo(): - input = (10, 6) - tiler = (3, 4) - result = cute.ceil_div(input, tiler) - print(result) # Outputs: (4, 2) - """ - input_val = _pack_shape(input, loc=loc, ip=ip) - tiler_val = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.ceil_div(input=input_val, tiler=tiler_val, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - - -def round_up(a: IntTuple, b: IntTuple) -> IntTuple: - """ - Rounds up elements of a using elements of b. - """ - if isinstance(a, tuple): - if not a: - raise ValueError(f"inputs cannot be empty") - if not isinstance(b, tuple): - raise TypeError( - f"expects both inputs to be tuple, but got {type(a)} and {type(b)}" - ) - if rank(a) < rank(b): - raise ValueError( - f"expects rank(a) to be greater or equal than rank(b), but got {a}, {b}" - ) - b = append(b, 1, rank(a)) - return tuple(round_up(x, y) for x, y in zip(a, b)) - return ((a + b - 1) // b) * b - - -# -# Layout API (also used by tensors) -# - - -@dsl_user_op -def make_layout( - shape: Shape, *, stride: Union[Stride, None] = None, loc=None, ip=None -) -> Layout: - """Create a CuTe Layout object from shape and optional stride information. - - A Layout in CuTe represents the mapping between logical and physical coordinates of a tensor. - This function creates a Layout object that defines how tensor elements are arranged in memory. - - :param shape: Shape of the layout defining the size of each mode - :type shape: Shape - :param stride: Optional stride values for each mode, defaults to None - :type stride: Union[Stride, None] - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A new Layout object with the specified shape and stride - :rtype: Layout - - **Examples:** - - .. code-block:: python - - # Create a 2D compact left-most layout with shape (4,4) - layout = make_layout((4,4)) # compact left-most layout - - # Create a left-most layout with custom strides - layout = make_layout((4,4), stride=(1,4)) # left-most layout with strides (1,4) - - # Create a layout for a 3D tensor - layout = make_layout((32,16,8)) # left-most layout - - # Create a layout with custom strides - layout = make_layout((2,2,2), stride=(4,1,2)) # layout with strides (4,1,2) - - Note: - - If stride is not provided, a default compact left-most stride is computed based on the shape - - The resulting layout maps logical coordinates to physical memory locations - - The layout object can be used for tensor creation and memory access patterns - - Strides can be used to implement: - * Row-major vs column-major layouts - * Padding and alignment - * Blocked/tiled memory arrangements - * Interleaved data formats - - Stride is keyword only argument to improve readability, e.g. - * make_layout((3,4), (1,4)) can be confusing with make_layout(((3,4), (1,4))) - * make_layout((3,4), stride=(1,4)) is more readable - """ - if stride is not None and not is_congruent(shape, stride): - raise ValueError(f"shape and stride must be congruent") - - shape_val = _pack_shape(shape, loc=loc, ip=ip) - if stride is not None: - stride_val = _pack_stride(stride, loc=loc, ip=ip) - layout_ty = _cute_ir.LayoutType.get(shape_val, stride_val) - else: - stride_val = None - layout_ty = _cute_ir.LayoutType.get(shape_val) - - return _cute_ir.make_layout( - layout_ty, shape=shape_val, stride=stride_val, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_identity_layout(shape: Shape, *, loc=None, ip=None) -> Layout: - """Create an identity layout with the given shape. - - An identity layout maps logical coordinates directly to themselves without any transformation. - This is equivalent to a layout with stride (1@0,1@1,...,1@(N-1)). - - :param shape: The shape of the layout - :type shape: Shape - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A new identity Layout object with the specified shape - :rtype: Layout - - **Examples:** - - .. code-block:: python - - # Create a 2D identity layout with shape (4,4) - layout = make_identity_layout((4,4)) # stride=(1@0,1@1) - - # Create a 3D identity layout - layout = make_identity_layout((32,16,8)) # stride=(1@0,1@1,1@2) - - Note: - - An identity layout is a special case where each coordinate maps to itself - - Useful for direct coordinate mapping without any transformation - """ - if not is_int_tuple(shape): - raise TypeError(f"expects a shape input, got {type(shape)}") - shape_val = _pack_shape(shape, loc=loc, ip=ip) - return _cute_ir.make_identity_layout(shape_val, loc=loc, ip=ip) - - -@dsl_user_op -def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Layout: - """Create a layout with a specific ordering of dimensions. - - This function creates a layout where the dimensions are ordered according to the - specified order parameter, allowing for custom dimension ordering in the layout. - - :param shape: The shape of the layout - :type shape: Shape - :param order: The ordering of dimensions - :type order: Shape - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A new Layout object with the specified shape and dimension ordering - :rtype: Layout - - **Examples:** - - .. code-block:: python - - # Create a row-major layout - layout = make_ordered_layout((4,4), order=(1,0)) - - # Create a column-major layout - layout = make_ordered_layout((4,4), order=(0,1)) # stride=(1,4) - - # Create a layout with custom dimension ordering for a 3D tensor - layout = make_ordered_layout((32,16,8), order=(2,0,1)) # stride=(128,1,16) - - Note: - - The order parameter specifies the ordering of dimensions from fastest-varying to slowest-varying - - For a 2D tensor, (0,1) creates a column-major layout, while (1,0) creates a row-major layout - - The length of order must match the rank of the shape - """ - shape_val = _pack_shape(shape, loc=loc, ip=ip) - order_val = _pack_int_tuple(order, loc=loc, ip=ip) - return _cute_ir.make_ordered_layout( - shape=shape_val, order=order_val, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_composed_layout( - inner, offset: IntTuple, outer: Layout, *, loc=None, ip=None -) -> ComposedLayout: - """Create a composed layout by composing an inner transformation with an outer layout. - - A composed layout applies a sequence of transformations - to coordinates. The composition is defined as (inner ∘ offset ∘ outer), where the operations - are applied from right to left. - - :param inner: The inner transformation (can be a Layout or Swizzle) - :type inner: Union[Layout, Swizzle] - :param offset: An integral offset applied between transformations - :type offset: IntTuple - :param outer: The outer (right-most) layout that is applied first - :type outer: Layout - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A new ComposedLayout representing the composition - :rtype: ComposedLayout - - **Examples:** - - .. code-block:: python - - # Create a basic layout - inner = make_layout(...) - outer = make_layout((4,4), stride=(E(0), E(1))) - - # Create a composed layout with an offset - composed = make_composed_layout(inner, (2,0), outer) - - Note: - - The composition applies transformations in the order: outer → offset → inner - - The stride divisibility condition must be satisfied for valid composition - - Certain compositions (like Swizzle with scaled basis) are invalid and will raise errors - - Composed layouts inherit many properties from the outer layout - """ - if not isinstance(outer, Layout): - raise TypeError( - f"expects the outer (or right-most or effectively visible) layout to be an affine layout, but got {outer}" - ) - if isinstance(inner, Swizzle) and has_scaled_basis(outer.stride): - raise TypeError(f"invalid composition {inner} o {offset} o {outer}") - offset_val = _pack_int_tuple(offset, loc=loc, ip=ip) - return _cute_ir.make_composed_layout(inner, offset_val, outer, loc=loc, ip=ip) - - -@dsl_user_op -def cosize( - a: Union[Layout, ComposedLayout, Tensor], mode: List[int] = [], *, loc=None, ip=None -): - """Return size of codomain of layout or tensor. Return static value if type is static. - - :param a: Layout, ComposedLayout, or Tensor object - :type a: Union[Layout, ComposedLayout, Tensor] - :param mode: List of mode(s) for cosize calculation - :type mode: List[int], optional - :param loc: Location information for diagnostics, defaults to None - :type loc: optional - :param ip: Instruction pointer for diagnostics, defaults to None - :type ip: optional - :return: Static size of layout or tensor (fast fold) if static, or a dynamic Value - :rtype: Union[int, Value] - """ - if any(not is_static(m) for m in mode): - raise ValueError(f"expects static mode, but got {mode}") - - if isinstance(a, _Tensor): - a = a.value - res = _cute_ir.cosize(a, mode=mode, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - - -@dsl_user_op -def size_in_bytes( - dtype: Type[Numeric], layout: Union[Layout, ComposedLayout], *, loc=None, ip=None -): - """Calculate the size in bytes based on its data type and layout. - - :param dtype: The DSL numeric data type - :type dtype: Type[Numeric] - :param layout: The layout of the elements. If None, the function returns 0 - :type layout: Layout, optional - :param loc: Location information for diagnostics, defaults to None - :type loc: optional - :param ip: Instruction pointer for diagnostics, defaults to None - :type ip: optional - :return: The total size in bytes. Returns 0 if the layout is None - :rtype: int - """ - if not isinstance(dtype, NumericMeta): - raise TypeError(f"dtype must be a Numeric, but got {dtype}") - - if layout is None: - return 0 - elif isinstance(layout, ComposedLayout): - if not isinstance(layout.inner, Swizzle): - raise TypeError( - f"invalid composed layout {layout}, inner must be a Swizzle" - ) - else: - return cosize(layout.outer, loc=loc, ip=ip) * dtype.width // 8 - else: - return cosize(layout, loc=loc, ip=ip) * dtype.width // 8 - - -@dsl_user_op -def coalesce(input, *, target_profile: Coord = None, loc=None, ip=None): - if target_profile: - profile_val = _pack_coord(target_profile, loc=loc, ip=ip) - return _cute_ir.coalesce(input, target_profile=profile_val, loc=loc, ip=ip) - else: - return _cute_ir.coalesce(input, loc=loc, ip=ip) - - -@dsl_user_op -def crd2idx(coord: Coord, layout, *, loc=None, ip=None): - """ - Convert a multi-dimensional coordinate into a value using the specified layout. - - This function computes the inner product of the flattened coordinate and stride: - - index = sum(flatten(coord)[i] * flatten(stride)[i] for i in range(len(coord))) - - :param coord: A tuple or list representing the multi-dimensional coordinate - (e.g., (i, j) for a 2D layout). - :type coord: Coord - :param layout: A layout object that defines the memory storage layout, including shape and stride, - used to compute the inner product. - :type layout: Layout or ComposedLayout - :param loc: Optional location information for IR diagnostics. - :type loc: optional - :param ip: Optional instruction pointer or context for underlying IR functions. - :type ip: optional - :returns: The result of applying the layout transformation to the provided coordinate. - :rtype: Any type that the layout maps to - - Example: - - .. code-block:: python - - import cutlass.cute as cute - @cute.jit - def foo(): - L = cute.make_layout((5, 4), stride=(4, 1)) - idx = cute.crd2idx((2, 3), L) - # Computed as: 2 * 4 + 3 = 11 - print(idx) - foo() # Expected output: 11 - """ - coord_val = _pack_coord(coord, loc=loc, ip=ip) - if isinstance(layout, (tuple, int)): - layout = make_layout(layout, loc=loc, ip=ip) - - res = _cute_ir.crd2idx(coord_val, layout, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - - -@dsl_user_op -def recast_layout(new_type_bits, old_type_bits, src_layout, *, loc=None, ip=None): - return _cute_ir.recast_layout( - new_type_bits, old_type_bits, src_layout, loc=loc, ip=ip - ) - - -@dsl_user_op -def slice_and_offset(coord, src, *, loc=None, ip=None): - layout = slice_(src, coord, loc=loc, ip=ip) - offset = crd2idx(coord, src, loc=loc, ip=ip) - return layout, offset - - -@dsl_user_op -@lru_cache_ir() -def shape( - input: Union[Shape, Tensor, Layout, Tile], *, mode=None, loc=None, ip=None -) -> Shape: - """Returns the shape of a tensor, layout or tiler. - - For shapes, this function is identical to get. - - This function extracts the shape information from the input object. For tensors and layouts, - it returns their internal shape property. For tilers, it unpacks the shape from the tile - representation. - - :param input: The object to extract shape from - :type input: Union[Tensor, Layout, Tile] - :param mode: Optional mode selector to extract specific dimensions from the shape - :type mode: Optional[int] - :param loc: Source location for MLIR operation tracking - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation - :type ip: Optional[InsertionPoint] - :return: The shape of the input object, optionally filtered by mode - :rtype: Shape - - Example: - - .. code-block:: python - - # Get shape of a layout - l0 = cute.make_layout((2, 3, 4)) - s0 = cute.shape(l0) # => (2, 3, 4) - - # Get shape of a hierarchical tiler - l1 = cute.make_layout(1) - s1 = cute.shape((l0, l1)) # => ((2, 3, 4), 1) - - # Get specific mode from a shape - s2 = cute.shape(l0, mode=0) # => 2 - """ - if is_int_tuple(input): - return get(input, mode=mode) - - if isinstance(input, (Tensor, Layout)): - shp = input.shape - else: - val = _cute_ir.get_shape(_pack_tile(input, loc=loc, ip=ip)) - shp = _unpack_x_tuple(val, loc=loc, ip=ip) - return get(shp, mode=mode) - - -# -# Pointer API -# - - -@dsl_user_op -def recast_ptr( - ptr: Pointer, - swizzle_=None, - dtype: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Pointer: - if dtype is not None: - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - dtype = dtype.mlir_type - - value_type = ptr.type.value_type if dtype is None else dtype - swizzle = swizzle_.type.attribute if swizzle_ is not None else None - res_ty = _cute_ir.PtrType.get(value_type, ptr.memspace, ptr.alignment, swizzle) - return _cute_ir.recast_iter(res_ty, ptr.value, loc=loc, ip=ip) - - -@dsl_user_op -def make_ptr( - dtype: Union[Type[Numeric], None], - value, - mem_space: AddressSpace = AddressSpace.generic, - *, - assumed_align=None, - loc=None, - ip=None, -) -> Pointer: - if dtype is None or not isinstance(dtype, NumericMeta): - raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}") - - if not isinstance(mem_space, AddressSpace): - raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}") - - if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type): - value = llvm.ptrtoint(T.i64(), value) - - if not is_integer(value): - raise TypeError(f"expects integer value, but got {type(value)}") - value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) - - bytes_per_elt = max(1, dtype.width // 8) - if assumed_align is None: - assumed_align = bytes_per_elt - - if bytes_per_elt % assumed_align != 0 and assumed_align % bytes_per_elt != 0: - raise ValueError( - f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa." - ) - - aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width) - aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip) - - data_ty = T.i8() if dtype is None else dtype.mlir_type - ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align) - return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip) - - -# -# Tensor API -# - - -@dsl_user_op -def make_tensor( - iterator, layout: Union[Shape, Layout, ComposedLayout], *, loc=None, ip=None -) -> Tensor: - """Creates a tensor by composing an engine (iterator/pointer) with a layout. - - A tensor is defined as T = E ∘ L, where E is an engine (array, pointer, or counting iterator) - and L is a layout that maps logical coordinates to physical offsets. The tensor - evaluates coordinates by applying the layout mapping and dereferencing the engine - at the resulting offset. - - :param iterator: Engine component (pointer, iterator, or counting iterator) that provides - data access capabilities - :type iterator: Union[Pointer, IntTuple] - :param layout: Layout component that defines the mapping from logical coordinates to - physical offsets - :type layout: Union[Shape, Layout, ComposedLayout] - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A tensor object representing the composition E ∘ L - :rtype: Tensor - - :raises ValueError: If iterator type is not supported - - **Examples:** - - .. code-block:: python - - # Create a tensor with row-major layout - layout = make_layout((64, 128), stride=(128, 1)) - tensor = make_tensor(ptr, layout) - - # Create a tensor with hierarchical layout - layout = make_layout(((128, 8), (1, 4, 1)), stride=((32, 1), (0, 8, 4096))) - tensor = make_tensor(smem_ptr, layout) - - # Create a coord tensor - layout = make_layout(2, stride=16 * E(0)) - tensor = make_tensor(5, layout) - - Notes: - - The engine (iterator) must support random access operations - - Common engine types include raw pointers, arrays, and random-access iterators - - The layout defines both the shape (logical dimensions) and stride (physical mapping) - - Supports both direct coordinate evaluation T(c) and partial evaluation (slicing) - """ - if not isinstance(layout, (Layout, ComposedLayout)): - layout = make_layout(layout, loc=loc, ip=ip) - elif isinstance(layout, ComposedLayout) and layout.type.is_normal_layout: - layout = layout.outer - - ty = None - if is_integer(iterator) or isinstance(iterator, tuple): - iterator = _pack_int_tuple(iterator, loc=loc, ip=ip) - ty = _cute_ir.CoordTensorType.get(iterator.type, layout.type) - elif isinstance(iterator, Pointer): - iterator = iterator.value - ty = _cute_ir.MemRefType.get(iterator.type, layout.type) - else: - raise TypeError(f"unsupported iterator type, got {type(iterator)}") - - return _cute_ir.make_view(result=ty, iter=iterator, layout=layout, loc=loc, ip=ip) - - -@dsl_user_op -def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: - """Creates an identity tensor with the given shape. - - An identity tensor maps each coordinate to itself, effectively creating a counting - sequence within the shape's bounds. This is useful for generating coordinate indices - or creating reference tensors for layout transformations. - - :param shape: The shape defining the tensor's dimensions. Can be a simple integer - sequence or a hierarchical structure ((m,n),(p,q)) - :type shape: Shape - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A tensor that maps each coordinate to itself - :rtype: Tensor - - **Examples:** - - .. code-block:: python - - # Create a simple 1D coord tensor - tensor = make_identity_tensor(6) # [0,1,2,3,4,5] - - # Create a 2D coord tensor - tensor = make_identity_tensor((3,2)) # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)] - - # Create hierarchical coord tensor - tensor = make_identity_tensor(((2,1),3)) - # [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)] - - Notes: - - The shape parameter follows CuTe's IntTuple concept - - Coordinates are ordered colexicographically - - Useful for generating reference coordinates in layout transformations - """ - shape_val = _pack_shape(shape, loc=loc, ip=ip) - return _cute_ir.make_identity_tensor(shape_val, loc=loc, ip=ip) - - -@dsl_user_op -def make_fragment( - layout_or_shape: Union[Layout, Shape], - dtype: Type[Numeric], - *, - loc=None, - ip=None, -) -> Tensor: - if not issubclass(dtype, Numeric): - raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}") - elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8() - - # Alignment for register memory is useless(?), pick-up large enough number - # to allow .128 (> 16B) load store - alignment = 32 - layout = None - if not isinstance(layout_or_shape, Layout): - layout = make_layout(layout_or_shape, loc=loc, ip=ip) - else: - layout = layout_or_shape - - ptr_ty = _cute_ir.PtrType.get(elem_ty, AddressSpace.rmem, alignment) - res_ty = _cute_ir.MemRefType.get(ptr_ty, layout.type) - tensor = _cute_ir.memref_alloca(res_ty, layout=layout, loc=loc, ip=ip) - return _Tensor(tensor.value, dtype) - - -@overload -def make_fragment_like( - src: Tensor, dtype: Optional[Type[Numeric]], *, loc=None, ip=None -) -> Tensor: ... - - -@overload -def make_fragment_like(src: Layout, *, loc=None, ip=None) -> Layout: ... - - -@overload -def make_fragment_like(src: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... - - -@dsl_user_op -def make_fragment_like(src, dtype=None, *, loc=None, ip=None): - """Create tensor with a compact layout in the same shape as the source on stack. - - This function either creates a fragment tensor with compact layout in - same shape as the source layout or a new layout with the same shape as the source. - The strides of the new layout follow the order induced by the source's strides, with a - special handling of the 0th mode: it is always stride-1 and generated in column-major order - (LayoutLeft). - - :param src: The source layout or tensor whose shape will be matched - :type src: Union[Layout, ComposedLayout, Tensor] - :param dtype: The element type for the fragment tensor, defaults to None - :type dtype: Type[Numeric], optional - :param loc: Source location for MLIR operations, defaults to None - :type loc: Location, optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: InsertionPoint, optional - - :return: A new layout or fragment tensor with matching shape - :rtype: Union[Layout, Tensor] - - **Examples:** - - Creating a rmem tensor from a tensor: - - .. code-block:: python - - smem_tensor = cute.make_tensor(smem_ptr, layout) - frag_tensor = cute.make_fragment_like(smem_tensor, cutlass.Float32) - # frag_tensor will be a register-backed tensor with the same shape - - Creating a fragment with a different element type: - - .. code-block:: python - - tensor = cute.make_tensor(gmem_ptr, layout) - bool_frag = cute.make_fragment_like(tensor, cutlass.Boolean) - # bool_frag will be a register-backed tensor with Boolean elements - - **Notes** - - - When used with a Tensor, if a type is provided, it will create a new - fragment tensor with that element type. - - For layouts with ScaledBasis strides, the function creates a fragment - from the shape only. - - This function is commonly used in GEMM and other tensor operations to - create register storage for intermediate results. - - """ - if isinstance(src, (Layout, ComposedLayout)): - new_layout = None - # Create base fragment layout - if isinstance(src, Layout) and has_scaled_basis(src.stride): - # For scaled basis strides, create fragment from shape only - new_layout = _cute_ir.make_fragment_like( - make_layout(src.shape), loc=loc, ip=ip - ) - else: - # Otherwise use full source layout - new_layout = _cute_ir.make_fragment_like(src, loc=loc, ip=ip) - if dtype is not None: - # call make_fragment to convert layout to tensor - return make_fragment(new_layout, dtype, loc=loc, ip=ip) - else: - return new_layout - elif isinstance(src, Tensor): - if isinstance(src.type, _cute_ir.CoordTensorType): - if dtype is None: - raise ValueError( - "dtype must be provided when src is a coordinate tensor" - ) - - new_layout = _cute_ir.make_fragment_like( - make_layout(src.shape), loc=loc, ip=ip - ) - return make_fragment(new_layout, dtype, loc=loc, ip=ip) - else: - dtype = src.element_type if dtype is None else dtype - ty = dtype.mlir_type if dtype is not Boolean else T.i8() - new_tensor = _cute_ir.make_fragment_like( - src.value, elem_type=ty, loc=loc, ip=ip - ) - return _Tensor(new_tensor.value, dtype) - else: - raise TypeError( - f"src must be a Layout or ComposedLayout or tensor, got {type(src)}" - ) - - -@dsl_user_op -def recast_tensor( - src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None -): - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - - if dtype is Boolean: - dst_width = 8 - else: - dst_width = dtype.width - - if src.element_type is Boolean: - src_width = 8 - else: - src_width = src.element_type.width - - src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip) - src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip) - return make_tensor(src_iter, src_layout, loc=loc, ip=ip) - - -@dsl_user_op -def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: - offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip) - if isinstance(tensor.iterator, Pointer): - return make_tensor(tensor.iterator + offset, tensor.layout) - elif is_integer(tensor.iterator) or isinstance(tensor.iterator, tuple): - new_iter = _cute_ir.add_offset( - _pack_int_tuple(tensor.iterator), _pack_int_tuple(offset) - ) - return make_tensor(_unpack_x_tuple(new_iter), tensor.layout) - else: - raise ValueError(f"unsupported tensor for domain_offset, got {tensor}") - - -# -# Layout algebra -# - - -@overload -def composition( - lhs: Layout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None -) -> Layout: ... - - -@overload -def composition( - lhs: Tensor, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None -) -> Tensor: ... - - -@dsl_user_op -def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None): - """ - Compose two layout representations using the CuTe layout algebra. - - Compose a left-hand layout (or tensor) with a right-hand operand into a new layout R, such that - for every coordinate c in the domain of the right-hand operand, the composed layout satisfies: - - R(c) = A(B(c)) - - where A is the left-hand operand provided as ``lhs`` and B is the right-hand operand provided as - ``rhs``. In this formulation, B defines the coordinate domain while A applies its transformation to - B's output, and the resulting layout R inherits the stride and shape adjustments from A. - - Satisfies: - cute.shape(cute.composition(lhs, rhs)) is compatible with cute.shape(rhs) - - :param lhs: The left-hand operand representing the transformation to be applied. - :type lhs: Layout or Tensor - :param rhs: The right-hand operand defining the coordinate domain. If provided as an int or tuple, - it will be converted to a tile layout. - :type rhs: Layout, Shape, or Tile, or int or tuple - :param loc: Optional location information for IR diagnostics. - :type loc: optional - :param ip: Optional instruction pointer or context for underlying IR functions. - :type ip: optional - :returns: A new composed layout R, such that for all coordinates c in the domain of ``rhs``, - R(c) = lhs(rhs(c)). - :rtype: Layout or Tensor - - Example: - - .. code-block:: python - - import cutlass.cute as cute - @cute.jit - def foo(): - # Create a layout that maps (i,j) to i*4 + j - L1 = cute.make_layout((2, 3), stride=(4, 1)) - # Create a layout that maps (i,j) to i*3 + j - L2 = cute.make_layout((3, 4), stride=(3, 1)) - # Compose L1 and L2 - L3 = cute.composition(L1, L2) - # L3 now maps coordinates through L2 then L1 - """ - rhs_val = rhs - if not isinstance(rhs, Layout) and isinstance(rhs, (int, tuple)): - rhs_val = _pack_tile(rhs, loc=loc, ip=ip) - if isinstance(lhs, _Tensor): - lhs = lhs.value - return _cute_ir.composition(lhs, rhs_val, loc=loc, ip=ip) - - -@dsl_user_op -def complement( - input: Layout, cotarget: Union[Layout, Shape], *, loc=None, ip=None -) -> Layout: - """ - Compute the complement layout of the input layout with respect to the cotarget. - - The complement of a layout A with respect to cotarget n is a layout A* such that - for every k in Z_n and c in the domain of A, there exists a unique c* in the domain - of A* where k = A(c) + A*(c*). - - This operation is useful for creating layouts that partition a space in complementary ways, - such as row and column layouts that together cover a matrix. - - :param input: The layout to compute the complement of - :type input: Layout - :param cotarget: The target layout or shape that defines the codomain - :type cotarget: Union[Layout, Shape] - :param loc: Optional location information for IR diagnostics - :type loc: optional - :param ip: Optional instruction pointer or context for underlying IR functions - :type ip: optional - :returns: The complement layout - :rtype: Layout - - Example: - - .. code-block:: python - - import cutlass.cute as cute - @cute.jit - def foo(): - # Create a right-major layout for a 4x4 matrix - row_layout = cute.make_layout((4, 4), stride=(4, 1)) - # Create a left-major layout that complements the row layout - col_layout = cute.complement(row_layout, 16) - # The two layouts are complementary under 16 - """ - if isinstance(cotarget, Layout): - return _cute_ir.complement(input, cotarget=cotarget, loc=loc, ip=ip) - else: - cotarget_val = _pack_shape(cotarget, loc=loc, ip=ip) - return _cute_ir.complement(input, cotarget=cotarget_val, loc=loc, ip=ip) - - -@dsl_user_op -def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout: - if not isinstance(input, Layout): - raise TypeError(f"expects input of type Layout, but got {type(input)}") - return _cute_ir.right_inverse(input=input, loc=loc, ip=ip) - - -@dsl_user_op -def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout: - if not isinstance(input, Layout): - raise TypeError(f"expects input of type Layout, but got {type(input)}") - return _cute_ir.left_inverse(input=input, loc=loc, ip=ip) - - -@overload -def logical_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... -@overload -def logical_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def logical_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.logical_product(input=block, tiler=tiler, loc=loc, ip=ip) - - -@overload -def zipped_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... -@overload -def zipped_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def zipped_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.zipped_product(input=block, tiler=tiler, loc=loc, ip=ip) - - -@overload -def tiled_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... -@overload -def tiled_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def tiled_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.tiled_product(input=block, tiler=tiler, loc=loc, ip=ip) - - -@overload -def flat_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... -@overload -def flat_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def flat_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.flat_product(input=block, tiler=tiler, loc=loc, ip=ip) - - -@overload -def raked_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... -@overload -def raked_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def raked_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.raked_product(input=block, tiler=tiler, loc=loc, ip=ip) - - -@overload -def blocked_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... -@overload -def blocked_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None -) -> ComposedLayout: ... - - -@dsl_user_op -def blocked_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.blocked_product(input=block, tiler=tiler, loc=loc, ip=ip) - - -@overload -def logical_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... -@overload -def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... - - -@dsl_user_op -def logical_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value - if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res - - -@overload -def zipped_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... -@overload -def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... - - -@dsl_user_op -def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value - if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res - - -@overload -def tiled_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... -@overload -def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... - - -@dsl_user_op -def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value - if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res - - -@overload -def flat_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... -@overload -def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... - - -@dsl_user_op -def flat_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value - if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res - - -# -# Higher-level utilties -# - - -@dsl_user_op -def max_common_layout( - a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None -) -> Layout: - a_layout = a.layout if isinstance(a, _Tensor) else a - b_layout = b.layout if isinstance(b, _Tensor) else b - - inv_b = right_inverse(b_layout, loc=loc, ip=ip) - common = coalesce(composition(a_layout, inv_b, loc=loc, ip=ip), loc=loc, ip=ip) - - # some_ir_value == 1 generates a new IR Value which evaluates to True! - s = get(common.shape, mode=[0], loc=loc, ip=ip) - d = get(common.stride, mode=[0], loc=loc, ip=ip) - # Keep only the static identity component of the common layout - if isinstance(s, int) and isinstance(d, int) and d == 1: - # Truncate to the size of the contiguous vector (static stride-1 mode) - return composition(inv_b, get(common, mode=[0], loc=loc, ip=ip), loc=loc, ip=ip) - else: - return make_layout(1, stride=0, loc=loc, ip=ip) - - -@dsl_user_op -def max_common_vector( - a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None -) -> int: - a_layout = a.layout if isinstance(a, _Tensor) else a - b_layout = b.layout if isinstance(b, _Tensor) else b - - inv_b = right_inverse(b_layout, loc=loc, ip=ip) - common = coalesce(composition(a_layout, inv_b, loc=loc, ip=ip), loc=loc, ip=ip) - - # Keep only the static identity component of the common layout - if ( - is_static(get(common.shape, mode=[0], loc=loc, ip=ip)) - and get(common.stride, mode=[0], loc=loc, ip=ip) == 1 - ): - # Truncate to the size of the contiguous vector (static stride-1 mode) - return get(common.shape, mode=[0], loc=loc, ip=ip) - else: - return 1 - - -@dsl_user_op -def tile_to_shape( - atom: Union[Layout, ComposedLayout], - trg_shape: Shape, - order: Shape, - *, - loc=None, - ip=None, -) -> Union[Layout, ComposedLayout]: - trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) - order = _pack_int_tuple(order, loc=loc, ip=ip) - return _cute_ir.tile_to_shape(atom, trg_shape, order, loc=loc, ip=ip) - - -@dsl_user_op -def local_partition( - target: Tensor, - tiler: Union[Layout, Shape], - index: Union[int, Numeric], - proj: XTuple = 1, - *, - loc=None, - ip=None, -) -> Tensor: - if isinstance(index, cutlass_arith.ArithValue): - index_val = index - else: - index_val = index.ir_value() - if index_val.type.width > 32: - raise NotImplementedError( - f"Index value should be 32-bit or smaller integer type, but got {index_val.type}" - ) - return _cute_ir.local_partition( - input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip - ) - - -@dsl_user_op -def local_tile( - input: Tensor, - tiler: Union[Layout, Shape], - coord: Coord, - proj: XTuple = None, - *, - loc=None, - ip=None, -) -> Tensor: - tiler_val = _pack_shape(tiler, loc=loc, ip=ip) - coord_val = _pack_coord(coord, loc=loc, ip=ip) - if proj is not None: - if not isinstance(proj, tuple): - raise TypeError(f"Expects tuple for proj, but got {type(proj)}") - proj_val = _pack_coord(proj, loc=loc, ip=ip) - proj = proj_val.type.attribute - - return _cute_ir.local_tile( - input=input.value, - tile=tiler_val, - static_tile=None, - coord=coord_val, - static_coord=None, - proj=proj, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_layout_image_mask( - lay: Layout, coord: Coord, mode: int, *, loc=None, ip=None -) -> Int16: - """ - Makes a 16-bit integer mask of the image of a layout sliced at a given mode - and accounting for the offset given by the input coordinate for the other modes. - """ - if not is_static(lay): - raise ValueError( - f"make_layout_image_mask requires the layout to be static, but got {pretty_str(lay)}" - ) - r = rank(lay) - if rank(coord) != r: - raise ValueError( - f"the rank of the coordinate must be equal to the one of the layout, but got {pretty_str(coord)}" - ) - if mode > r or mode < 0: - raise ValueError(f"expects `mode` to be in [0,rank(lay)), but got {mode}") - # Given that we require the layout to be static, we can check that the mask fits in 16 bits - # This might be too conservative but safe - if cosize(lay) > 16: - raise ValueError("the mask may not fit into a 16-bit integer") - - # Replace the mode to keep with _ in the coordinate - slicer = tuple(None if idx == mode else x for idx, x in enumerate(coord)) - # Slice the layout with the slicer above and keep track of the offset - sliced_lay, offset = slice_and_offset(slicer, lay, loc=loc, ip=ip) - # Given that we replace only one mode with _, the rank of the slice should be 1 - assert rank(sliced_lay) == 1 - - # Create the mask of the image - mcast_mask = Int16(0) - for i in range(size(sliced_lay)): - mcast_mask = mcast_mask | (1 << sliced_lay(i)) - mcast_mask <<= offset - return Int16(mcast_mask) - - -#################################################################################################### -# -# Atom -# -#################################################################################################### - - -class Op(ABC): - """ - Operation abstract base class. - """ - - pass - - -class MmaOp(Op): - """ - MMA Operation abstract base class. - """ - - @abstractmethod - def _make_trait(self, *, loc=None, ip=None, **kwargs): - pass - - -class CopyOp(Op): - """ - Copy Operation abstract base class. - """ - - @abstractmethod - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ): - pass - - -class Trait(ABC): - """ - Trait abstract base class. - - Traits are internal-only classes used by Atoms that wrap the underlying IR Value. The Python - user should only interact with Ops and Atoms. - """ - - def __init__(self, value: ir.Value) -> None: - self.value = value - - def __extract_mlir_values__(self): - return [self.value] - - def __new_from_mlir_values__(self, values): - return self.__class__(values[0]) - - def set(self, field, value, *, loc=None, ip=None) -> None: - raise NotImplementedError( - "set not implemented, the requesting Atom has likely no runtime state" - ) - - def unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: - return self.value - - -class Atom(ABC): - """ - Atom base class. - - An Atom is the composition of - - - a MMA or Copy Operation; - - an internal MMA or Copy Trait. - - An Operation is a pure Python class that is used to model a specific MMA or Copy instruction. - The Trait wraps the underlying IR Value and provides access to the metadata of the instruction - encoded using CuTe Layouts. When the Trait can be constructed straighforwardly from an - Operation, the ``make_mma_atom`` or ``make_copy_atom`` API should be used. There are cases where - constructing the metadata is not trivial and requires more information, for example to determine - the number of bytes copied per TMA instruction ("the TMA vector length"). In such cases, - dedicated helper functions are provided with an appropriate API such that the Atom is - constructed internally in an optimal fashion for the user. - """ - - def __init__(self, op: Op, trait: Trait) -> None: - self._op = op - self._trait = trait - - def __extract_mlir_values__(self): - return extract_mlir_values(self._trait) - - def __new_from_mlir_values__(self, values): - return self.__class__(self.op, new_from_mlir_values(self._trait, values)) - - @property - def op(self) -> Op: - return self._op - - @property - def type(self): - return self._trait.value.type - - @dsl_user_op - def set(self, modifier, value, *, loc=None, ip=None) -> None: - """ - Sets runtime fields of the Atom. - - Some Atoms have runtime state, for example a tcgen05 MMA Atom - - - .. code-block:: python - - tiled_mma = cute.make_tiled_mma(some_tcgen05_mma_op) - tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) - - The ``set`` method provides a way to the user to modify such runtime state. Modifiable - fields are provided by arch-specific enumerations, for example ``tcgen05.Field``. The Atom - instance internally validates the field as well as the value provided by the user to set - the field to. - """ - self._trait.set(modifier, value, loc=loc, ip=ip) - - def _unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: - return self._trait.unpack(loc=loc, ip=ip, **kwargs) - - -#################################################################################################### -# -# MMA Atoms, TiledMma, and ThrMma -# -#################################################################################################### - - -class MmaAtom(Atom): - """ - The MMA Atom class. - """ - - def __str__(self) -> str: - res = "MMA Atom\n" - res += " ThrID: " + pretty_str(self.thr_id) + "\n" - res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" - res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" - res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" - res += " TV Layout C: " + pretty_str(self.tv_layout_C) - return res - - # - # Properties - # - - @property - def thr_id(self) -> Layout: - return _cute_ir.static(self._trait.value.type.thr_id) - - @property - def shape_mnk(self) -> Shape: - return _unpack_x_tuple(self._trait.value.type.shape_mnk) - - @property - def tv_layout_A(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_a_tv) - - @property - def tv_layout_B(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_b_tv) - - @property - def tv_layout_C(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_c_tv) - - # - # make_fragment - # - - @dsl_user_op - def make_fragment_A(self, input, *, loc=None, ip=None): - # input could be memref/shape/layout for tmem based fragment - if isinstance(input, _Tensor): - if self.op is not None: - self.op._verify_fragment_A(input, loc=loc, ip=ip) - input = input.value - if isinstance(input, tuple): - input = _pack_shape(input, loc=loc, ip=ip) - return _cute_ir.mma_make_fragment( - _cute_ir.MmaOperand.A, - self._trait.value, - input, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def make_fragment_B(self, input, *, loc=None, ip=None): - if isinstance(input, _Tensor): - if self.op is not None: - self.op._verify_fragment_B(input, loc=loc, ip=ip) - input = input.value - return _cute_ir.mma_make_fragment( - _cute_ir.MmaOperand.B, - self._trait.value, - input, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def make_fragment_C(self, input, *, loc=None, ip=None): - # input could be memref/shape/layout for tmem based fragment - if isinstance(input, _Tensor): - input = input.value - if isinstance(input, tuple): - input = _pack_shape(input, loc=loc, ip=ip) - return _cute_ir.mma_make_fragment( - _cute_ir.MmaOperand.C, - self._trait.value, - input, - loc=loc, - ip=ip, - ) - - -class TiledMma(MmaAtom): - """ - The tiled MMA class. - """ - - def __str__(self) -> str: - res = "Tiled MMA\n" - res += " Thr Layout VMNK: " + pretty_str(self.thr_layout_vmnk) + "\n" - res += " Permutation MNK: " + pretty_str(self.permutation_mnk) + "\n" - res += "MMA Atom\n" - res += " ThrID: " + pretty_str(self.thr_id) + "\n" - res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" - res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" - res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" - res += " TV Layout C: " + pretty_str(self.tv_layout_C) - return res - - # - # Properties - # - - @property - def tv_layout_A_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_a_tv_tiled) - - @property - def tv_layout_B_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_b_tv_tiled) - - @property - def tv_layout_C_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_c_tv_tiled) - - @property - def permutation_mnk(self) -> Tile: - return _unpack_x_tuple(self._trait.value.type.permutation_mnk) - - @property - def thr_layout_vmnk(self) -> Layout: - return _cute_ir.static(self._trait.value.type.thr_layout_vmnk) - - @property - def size(self) -> int: - return self._trait.value.type.size - - # - # Tiler - # - - def get_tile_size(self, mode_idx: int) -> Shape: - assert (mode_idx >= 0) and (mode_idx < 3) - perm_tile = self.permutation_mnk[mode_idx] - if perm_tile is None: - thr_layout_vmnk = self.thr_layout_vmnk - atom_shape_mnk = self.shape_mnk - return size(atom_shape_mnk, mode=[mode_idx]) * size( - thr_layout_vmnk, mode=[mode_idx + 1] - ) - else: - return size(perm_tile) - - # - # get_slice - # - - def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrMma": - return ThrMma(self.op, self._trait, thr_idx) - - # - # partition_shape - # - - def _partition_shape(self, operand_id, shape, *, loc=None, ip=None): - shape = _pack_shape(shape, loc=loc, ip=ip) - return _unpack_x_tuple( - _cute_ir.tiled_mma_partition_shape( - operand_id, self._trait.value, shape, loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - - @dsl_user_op - def partition_shape_A(self, shape_mk, *, loc=None, ip=None): - return self._partition_shape(_cute_ir.MmaOperand.A, shape_mk, loc=loc, ip=ip) - - @dsl_user_op - def partition_shape_B(self, shape_nk, *, loc=None, ip=None): - return self._partition_shape(_cute_ir.MmaOperand.B, shape_nk, loc=loc, ip=ip) - - @dsl_user_op - def partition_shape_C(self, shape_mn, *, loc=None, ip=None): - return self._partition_shape(_cute_ir.MmaOperand.C, shape_mn, loc=loc, ip=ip) - - # - # _thrfrg - # - - @overload - def _thrfrg(self, operand_id, input: Layout, *, loc=None, ip=None) -> Layout: ... - - @overload - def _thrfrg(self, operand_id, input: Tensor, *, loc=None, ip=None) -> Tensor: ... - - def _thrfrg(self, operand_id, input, *, loc=None, ip=None) -> Union[Tensor, Layout]: - if isinstance(input, Tensor): - return make_tensor( - input.iterator, - self._thrfrg(operand_id, input.layout, loc=loc, ip=ip), - ) - elif isinstance(input, Layout): - if not is_static(input.type): - raise ValueError(f"Expects a static layout but got {input.type}") - return _cute_ir.static( - self._trait.value.type.thrfrg(operand_id, input), loc=loc, ip=ip - ) - - raise ValueError( - f"Expects a layout or a tensor as input but got {type(input)=}" - ) - - def _thrfrg_A( - self, input: Union[Layout, Tensor], *, loc=None, ip=None - ) -> Union[Layout, Tensor]: - return self._thrfrg(_cute_ir.MmaOperand.A, input, loc=loc, ip=ip) - - def _thrfrg_B( - self, input: Union[Layout, Tensor], *, loc=None, ip=None - ) -> Union[Layout, Tensor]: - return self._thrfrg(_cute_ir.MmaOperand.B, input, loc=loc, ip=ip) - - def _thrfrg_C( - self, input: Union[Layout, Tensor], *, loc=None, ip=None - ) -> Union[Layout, Tensor]: - return self._thrfrg(_cute_ir.MmaOperand.C, input, loc=loc, ip=ip) - - -class ThrMma(TiledMma): - """ - The thread MMA class for modeling a thread-slice of a tiled MMA. - """ - - def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: - super().__init__(op, trait) - self._thr_idx = thr_idx - - def __new_from_mlir_values__(self, values): - return self.__class__( - self.op, new_from_mlir_values(self._trait, values), self.thr_idx - ) - - @property - def thr_idx(self): - return self._thr_idx - - @dsl_user_op - def partition_A(self, input_mk: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_mma_partition( - _cute_ir.MmaOperand.A, - self._trait.value, - input_mk.value, - thr_idx, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def partition_B(self, input_nk: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_mma_partition( - _cute_ir.MmaOperand.B, - self._trait.value, - input_nk.value, - thr_idx, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def partition_C(self, input_mn: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_mma_partition( - _cute_ir.MmaOperand.C, - self._trait.value, - input_mn.value, - thr_idx, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom: - """ - Makes an MMA Atom from an MMA Operation. - - This function creates an MMA Atom from a given MMA Operation. Arbitrary kw arguments can be - provided for Op-specific additional parameters. They are not used as of today. - - :param op: The MMA Operation to construct an Atom for - :type op: MmaOp - :return: The MMA Atom - :rtype: MmaAtom - """ - trait = op._make_trait(loc=loc, ip=ip, **kwargs) - return MmaAtom(op, trait) - - -@dsl_user_op -def make_tiled_mma( - op_or_atom: Union[Op, MmaAtom], - atom_layout_mnk=(1, 1, 1), - permutation_mnk=None, - *, - loc=None, - ip=None, - **kwargs, -) -> TiledMma: - """ - Makes a tiled MMA from an MMA Operation or an MMA Atom. - - :param op_or_atom: The MMA Operation or Atom - :type op_or_atom: Union[Op, MmaAtom] - :param atom_layout_mnk: A Layout describing the tiling of Atom across threads - :type atom_layout_mnk: Layout - :param permutation_mnk: A permutation Tiler describing the tiling of Atom across values including any permutation of such tiling - :type permutation_mnk: Tiler - :return: The resulting tiled MMA - :rtype: TiledMma - """ - if isinstance(op_or_atom, Op): - op = op_or_atom - atom = make_mma_atom(op_or_atom, loc=loc, ip=ip, **kwargs) - elif isinstance(op_or_atom, MmaAtom): - op = op_or_atom.op - atom = op_or_atom - else: - raise TypeError( - f"expected an MMA Op or Atom, but got an instance of {type(op_or_atom)}" - ) - if isinstance(atom_layout_mnk, tuple): - atom_layout_mnk = make_layout(atom_layout_mnk, loc=loc, ip=ip) - if rank(atom_layout_mnk) != 3: - raise ValueError(f"expects rank-3 MNK atom layout, but got {atom_layout_mnk}") - permutation_mnk_ty = None - if permutation_mnk is not None: - permutation_mnk_ty = _pack_tile(permutation_mnk, loc=loc, ip=ip).type - ty = _cute_nvgpu_ir.TiledMmaType.get( - atom._trait.value.type, - atom_layout_mnk.type, - permutation_mnk_ty, - ) - val = _cute_ir.make_tiled_mma(ty, atom._trait.value, loc=loc, ip=ip) - # Instead of modifying atom which might have been provided by the user, create a brand new - # trait instance and replace the Atom ir.Value with the tiled one - trait = new_from_mlir_values(atom._trait, [val]) - return TiledMma(op, trait) - - -#################################################################################################### -# -# Copy Atoms, TiledCopy, and ThrCopy -# -#################################################################################################### - - -class CopyAtom(Atom): - """ - The Copy Atom class. - """ - - def __str__(self) -> str: - res = "Copy Atom\n" - res += " ThrID: " + str(self.thr_id) + "\n" - res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" - res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" - res += " Value type: " + str(self._trait.value.type.value_type) - return res - - # - # Properties - # - - @property - def value_type(self) -> Type[Numeric]: - return Numeric.from_mlir_type(self._trait.value.type.value_type) - - @property - def thr_id(self) -> Layout: - return _cute_ir.static(self._trait.value.type.thr_id) - - @property - def layout_src_tv(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_src_tv) - - @property - def layout_dst_tv(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_dst_tv) - - -class TiledCopy(CopyAtom): - """ - The tiled Copy class. - """ - - def __str__(self) -> str: - res = "Tiled Copy\n" - res += " Tiler MN: " + pretty_str(self.tiler_mn) + "\n" - res += " TV Layout tiled: " + str(self.layout_tv_tiled) + "\n" - res += "Copy Atom\n" - res += " ThrID: " + str(self.thr_id) + "\n" - res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" - res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" - res += " Value type: " + str(self._trait.value.type.value_type) - return res - - # - # Properties - # - - @property - def layout_tv_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_tv_tiled) - - @property - def tiler_mn(self) -> Tile: - return _unpack_x_tuple(self._trait.value.type.tiler_mn) - - @property - def layout_src_tv_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_src_tv_tiled) - - @property - def layout_dst_tv_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_dst_tv_tiled) - - @property - def size(self) -> int: - return self._trait.value.type.size - - # - # get_slice and retile - # - - def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrCopy": - return ThrCopy(self.op, self._trait, thr_idx) - - @dsl_user_op - def retile(self, src, *, loc=None, ip=None): - return _cute_ir.tiled_copy_retile( - tiled_copy=self._trait.value, input=src.value, loc=loc, ip=ip - ) - - -class ThrCopy(TiledCopy): - """ - The thread Copy class for modeling a thread-slice of a tiled Copy. - """ - - def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: - super().__init__(op, trait) - self._thr_idx = thr_idx - - def __new_from_mlir_values__(self, values): - return self.__class__( - self.op, new_from_mlir_values(self._trait, values), self.thr_idx - ) - - @property - def thr_idx(self): - return self._thr_idx - - @dsl_user_op - def partition_S(self, src: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_copy_partition_S( - self._trait.value, src.value, thr_idx, loc=loc, ip=ip - ) - - @dsl_user_op - def partition_D(self, dst: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_copy_partition_D( - self._trait.value, dst.value, thr_idx, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_copy_atom( - op: CopyOp, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs -) -> CopyAtom: - """ - Makes a Copy Atom from a Copy Operation. - - This function creates a Copy Atom from a given Copy Operation. Arbitrary kw arguments can be - provided for Op-specific additional parameters. - - Example: - - .. code-block:: python - - op = cute.nvgpu.CopyUniversalOp() - atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) - - :param op: The Copy Operation to construct an Atom for - :type op: CopyOp - :param copy_internal_type: An internal data type used to construct the source/destination layouts in unit of tensor elements - :type copy_internal_type: Type[Numeric] - :return: The Copy Atom - :rtype: CopyAtom - """ - trait = op._make_trait(copy_internal_type, loc=loc, ip=ip, **kwargs) - return CopyAtom(op, trait) - - -@dsl_user_op -def make_layout_tv( - thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None -) -> Tuple[Shape, Layout]: - """Create a thread-value layout for partitioning data tensors. - - This function creates a thread-value layout that maps between ``(thread_idx, value_idx)`` - coordinates and logical ``(M,N)`` coordinates. The thread layout must be compact to ensure - proper partitioning. - - This implements the thread-value partitioning pattern shown in - Figure TVLayout, where data is partitioned across threads and values within each thread. - - :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) - :type thr_layout: Layout - :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs within each thread - :type val_layout: Layout - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tuple containing ``tiler_mn`` and ``layout_tv`` - :rtype: Tuple[Shape, Layout] - - where: - * ``tiler_mn`` is tiler and ``shape(tiler_mn)`` is compatible with ``shape(zipped_divide(x, tiler_mn))[0]`` - * ``layout_tv``: Thread-value layout mapping (thread_idx, value_idx) -> (M,N) - - **Example:** - - .. code-block:: python - - tiler_mn, layout_tv = cute.make_layout_tv( - cute.make_layout((4, 8), stride=(8, 1)), cute.make_layout(2, stride=1) - ) - - Above code creates a TV layout that maps between thread/value coordinates - and the logical coordinates in a 8x8 matrix with: - - * thread block layout ``(4,8):(8,1)`` - * 2 elements per thread - """ - - if not isinstance(thr_layout, Layout): - raise TypeError(f"expected a Layout for thr_layout, but got {type(thr_layout)}") - if not isinstance(val_layout, Layout): - raise TypeError(f"expected a Layout for val_layout, but got {type(val_layout)}") - - # Take the raked_products to compute the Layout_MN - # (M,N) -> (thr_idx, val_idx) - layout_mn = raked_product(thr_layout, val_layout, loc=loc, ip=ip) - thr_size = size(thr_layout, loc=loc, ip=ip) - val_size = size(val_layout, loc=loc, ip=ip) - tmp = make_layout((thr_size, val_size), loc=loc, ip=ip) - # (thr_idx, val_idx) -> (M,N) - layout_tv = composition( - right_inverse(layout_mn, loc=loc, ip=ip), tmp, loc=loc, ip=ip - ) - - tiler_mn = product_each(layout_mn.shape, loc=loc, ip=ip) - - return (tiler_mn, layout_tv) - - -def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): - if type(tiler_mn) is tuple: - tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) - - assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance( - tiler_mn.type - ), f"tiler_mn must be a Tile, but got {type(tiler_mn)}" - assert is_static(layout_tv.type) and is_static( - tiler_mn.type - ), "layout tv and tiler mn must be static" - tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( - atom.type, layout_tv.type, tiler_mn.type - ) - - val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) - # Instead of modifying atom which might have been provided by the user, create a brand new - # trait instance and replace the Atom ir.Value with the tiled one - trait = new_from_mlir_values(atom._trait, [val]) - return TiledCopy(atom.op, trait) - - -def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): - """Create a tiled type given a TV partitioner and tiler. - - :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - :type atom: CopyAtom - :param layout_tv: Thread-value layout - :type layout_tv: Layout - :param tiler_mn: Tile size - :type tiler_mn: Tiler - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) - - -@dsl_user_op -def make_tiled_copy_tv( - atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None -) -> TiledCopy: - """Create a tiled copy given separate thread and value layouts. - - A TV partitioner is inferred based on the input layouts. The input thread layout - must be compact. - - :param atom: Copy atom - :type atom: CopyAtom - :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) - :type thr_layout: Layout - :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs - :type val_layout: Layout - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) - tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip) - return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) - - -@dsl_user_op -def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_mma: Tiled MMA - :type tiled_mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, - tiled_mma.tv_layout_A_tiled, - (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_mma: Tiled MMA - :type tiled_mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, - tiled_mma.tv_layout_B_tiled, - (tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_mma: Tiled MMA - :type tiled_mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, - tiled_mma.tv_layout_C_tiled, - (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_copy: Tiled copy - :type tiled_copy: TiledCopy - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_copy: Tiled copy - :type tiled_copy: TiledCopy - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): - """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. - - :param atom: Copy atom - :type atom: CopyAtom - :param mma: Tiled MMA - :type mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for partitioner - :rtype: TiledCopy - - :raises ValueError: If the number value of CopyAtom's source layout is greater than the size of TiledMma's LayoutC_TV - """ - # Truncate the V-layout to just the Copy_Atom, keep the V-order - layoutC_tv = mma.tv_layout_C_tiled - val_layout_src = atom.layout_src_tv - num_val_src = size(val_layout_src, mode=[1], loc=loc, ip=ip) - num_val_layoutC_tv = size(layoutC_tv, mode=[1], loc=loc, ip=ip) - if num_val_src > num_val_layoutC_tv: - raise ValueError( - f"The number value of CopyAtom's source layout {num_val_src} " - f"is greater than the size of TiledMma's LayoutC_TV {num_val_layoutC_tv}" - ) - layout_TV = composition( - layoutC_tv, - make_layout( - (size(layoutC_tv, mode=[0], loc=loc, ip=ip), num_val_src), loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - - # Recompute tiler and restride the TV layout for the new tiler - - # Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them - # Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA - mma_tiler = (mma.get_tile_size(0), mma.get_tile_size(1)) - - tiler_0 = filter( - composition( - make_layout(mma_tiler, stride=(1, 0), loc=loc, ip=ip), - layout_TV, - loc=loc, - ip=ip, - ), - loc=loc, - ip=ip, - ) - tiler_1 = filter( - composition( - make_layout(mma_tiler, stride=(0, 1), loc=loc, ip=ip), - layout_TV, - loc=loc, - ip=ip, - ), - loc=loc, - ip=ip, - ) - tiler = (tiler_0, tiler_1) - - tile2mma = composition( - make_layout(mma_tiler, loc=loc, ip=ip), tiler, loc=loc, ip=ip - ) - layout_tv = composition( - left_inverse(tile2mma, loc=loc, ip=ip), layout_TV, loc=loc, ip=ip - ) - - tiler_mn = _pack_tile(tiler, loc=loc, ip=ip) - - return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) - - -#################################################################################################### -# -# cute.gemm and cute.copy -# -#################################################################################################### - - -@dsl_user_op -def gemm( - atom: MmaAtom, - d: Tensor, - a: Tensor, - b: Tensor, - c: Tensor, - *, - loc=None, - ip=None, - **kwargs, -) -> None: - """The GEMM algorithm. - - Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. - warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. - - All tensors must be partitioned according to the provided MMA Atom. - - For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread - election internally. Manual thread selection is not required in such cases. - - Following dispatch rules are supported: - - - Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1) - - Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N) - - Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N) - - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) - - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) - - :param atom: MMA atom - :type atom: MmaAtom - :param d: Destination tensor - :type d: Tensor - :param a: First source tensor - :type a: Tensor - :param b: Second source tensor - :type b: Tensor - :param c: Third source tensor - :type c: Tensor - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR, defaults to None - :type ip: Optional[InsertionPoint], optional - :param kwargs: Additional keyword arguments - :type kwargs: dict - :return: None - :rtype: None - """ - - a_rank = rank(a.shape) - b_rank = rank(b.shape) - c_rank = rank(c.shape) - d_rank = rank(d.shape) - - if a_rank != b_rank: - raise ValueError("`a` and `b` must have the same rank") - - if c_rank != d_rank: - raise ValueError("`c` and `d` must have the same rank") - - if a_rank == 1: - if c_rank > 2: - raise ValueError("`c` must have rank <= 2 when `a` has rank 1") - elif a_rank == 2: - if c_rank not in (2, 3): - raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2") - elif a_rank == 3: - if c_rank != 3: - raise ValueError("`c` must have rank 3 when `a` has rank 3") - - value = atom._unpack(loc=loc, ip=ip, **kwargs) - return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) - - -@dsl_user_op -def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """Performs a basic element-wise copy. - - This functions **assumes** the following pre-conditions: - 1. `size(src) == size(dst)` - - When the `src` and `dst` shapes are static, the pre-conditions are actually verified and the - element-wise loop is fully unrolled. - - :param src: Source tensor - :type src: Tensor - :param dst: Destination tensor - :type dst: Tensor - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - """ - - if is_static(src.shape) and is_static(dst.shape): - simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - src.element_type.mlir_type, src.element_type.width - ) - simt_copy = _cute_ir.atom(simt_copy_ty, loc=loc, ip=ip) - return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip) - - s = size(dst, loc=loc, ip=ip) - # Always generate an scf.for Op when one of the tensors is dynamic - for i in for_generate(0, s): - dst[i] = src[i] - yield_out() - - -@dsl_user_op -def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """Performs a basic predicated element-wise copy. - - This functions **assumes** the following pre-conditions: - 1. `size(src) == size(dst)` - 2. `size(src) == size(pred)` - - When all shapes are static, the pre-conditions are actually verified and the element-wise loop - is fully unrolled. - - """ - if src.element_type.width != dst.element_type.width: - raise NotImplementedError( - "basic_copy_if currently only supports equal source and destination " - "element type bit width" - ) - - if is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape): - return _basic_copy_if_static(pred, src, dst, loc=loc, ip=ip) - - s = size(dst, loc=loc, ip=ip) - # Always generate an scf.for Op when one of the tensors is dynamic - for i in for_generate(0, s): - if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) - yield_out() - - -# Version of basic_copy_if when src and dst have static shapes -# - verify size(src) == size(dst) == size(prd) -# - fully unroll the loop for now -def _basic_copy_if_static( - pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None -) -> None: - assert is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape) - if size(src, loc=loc, ip=ip) != size(dst, loc=loc, ip=ip): - raise ValueError( - "basic_copy expects the size of source, destination, and predicate tensors to match" - ) - # Fully unrolled loop in the static case for now - for i in range(size(dst, loc=loc, ip=ip)): - if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) - - -@dsl_user_op -def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """ - Auto-vectorizing SIMT copy policy. - - Given a source and destination tensors that are statically shaped, this policy figures out the - largest safe vector width that the copy instruction can take and performs the copy. - """ - if src.element_type.width != dst.element_type.width: - raise NotImplementedError( - "autovec_copy currently only supports equal source and destination " - "element type bit width" - ) - - # We are going to dispatch to copy-with-atom which requires shapes to be static - if not is_static(src.shape) or not is_static(dst.shape): - raise ValueError( - "autovec_copy expects source and destination tensors to be statically shaped" - ) - - vec_layout = max_common_layout(src, dst, loc=loc, ip=ip) - num_common_elements = size(vec_layout, loc=loc, ip=ip) - - # Next we construct an upper-bound on the number bits that can be vectorized by considering - # - the maximum alignment of the layouts - # - the maximum alignment of the pointers - - upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) - upper_bound = math.gcd(upper_bound, num_common_elements) - upper_bound *= src.element_type.width - - # For our instructions, the alignment of the pointer is an upper bound to the vector width - # max_alignment, as opposed to alignment, takes into account possible address swizzling - upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) - upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) - - # Finally, we put a cap at 128b - num_bits_per_copy = math.gcd(upper_bound, 128) - - if (num_common_elements > 1) and (num_bits_per_copy % 8 == 0): - num_common_elements = num_bits_per_copy // src.element_type.width - - # 2 step logical divides ensuring that the divides are valid at every step - vec_src = logical_divide(src, vec_layout, loc=loc, ip=ip) - vec_dst = logical_divide(dst, vec_layout, loc=loc, ip=ip) - tiled_src = logical_divide( - vec_src, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip - ) - tiled_dst = logical_divide( - vec_dst, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip - ) - - # Dispatch to copy with atom - simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - src.element_type.mlir_type, num_bits_per_copy - ) - simt_copy = _cute_ir.atom(simt_type, loc=loc, ip=ip) - return _cute_ir.copy( - simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip - ) - - # Failed to vectorize, use a basic copy - basic_copy(src, dst, loc=loc, ip=ip) - - -@dsl_user_op -def copy( - atom: CopyAtom, - src: Tensor, - dst: Tensor, - *, - pred: Optional[Tensor] = None, - loc=None, - ip=None, - **kwargs, -) -> None: - """ - The Copy algorithm. - - The "copy with Atom" expects source and destination tensors to be partitioned according to the - provided Copy Atom. Some Atoms require additional Op-specific kw arguments, for example TMA - copies: - - .. code-block:: python - - cute.copy(tma_atom, src, dst, tma_bar_ptr=mbar_ptr, mcast_mask=mask) - - An additional predication tensor can be provided. If the partitioned tensors have the following - logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile - consistent with ``(ATOM_REST,REST_M,...)``. - - For Copy Atoms that require single-threaded execution, the copy op automatically handles thread - election internally. Manual thread selection is not required in such cases. - """ - if isinstance(src.type, _cute_ir.MemRefType) and isinstance( - dst.type, _cute_ir.MemRefType - ): - if src.element_type.width != dst.element_type.width: - raise TypeError( - "`copy` currently only supports equal source and destination " - "element type bit width" - ) - - value = atom._unpack(loc=loc, ip=ip, **kwargs) - if isinstance(pred, Tensor): - pred = pred.value - return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip) - - -@dsl_user_op -def copy_atom_call( - atom: CopyAtom, - src: Tensor, - dst: Tensor, - *, - pred: Optional[Tensor] = None, - loc=None, - ip=None, - **kwargs, -) -> None: - """ - Execute a single copy atom operation. - - The copy_atom_call operation executes a copy atom with the given operands. - Following src/dst layout of atom are valid: - * ((atom_v)) - * (atom_v) - - Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would - require multiple atom operations, which contradicts the definition of a single copy atom call. - - Examples: - - .. code-block:: python - - # Call a copy atom operation - cute.copy_atom_call(copy_atom, src_tensor, dst_tensor) - - An additional predication tensor can be provided. If the partitioned tensors have the following - logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile - consistent with ``(ATOM_REST,REST_M,...)``. - """ - if isinstance(src.type, _cute_ir.MemRefType) and isinstance( - dst.type, _cute_ir.MemRefType - ): - if src.element_type.width != dst.element_type.width: - raise TypeError( - "`copy_atom_call` currently only supports equal source and destination " - "element type bit width" - ) - - value = atom._unpack(loc=loc, ip=ip, **kwargs) - if isinstance(pred, Tensor): - pred = pred.value - return _cute_ir.copy_atom_call( - value, src.value, dst.value, pred=pred, loc=loc, ip=ip - ) - - -def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: - """ - The Prefetch algorithm. - - The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom. - Prefetch is used for loading tensors from global memory to L2. - - Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch. - - .. code-block:: python - - cute.prefetch(tma_atom, src) - - For Copy Atoms that require single-threaded execution, the copy op automatically handles thread - election internally. Manual thread selection is not required in such cases. - """ - dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip) - value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr) - return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip) - -#################################################################################################### -# -# TensorSSA class (experimental) -# -#################################################################################################### - - -class ReductionOp(Enum): - ADD = auto() - MUL = auto() - MAX = auto() - MIN = auto() - INC = auto() - DEC = auto() - AND = auto() - OR = auto() - XOR = auto() - - def __str__(self): - return self.name.lower() - - -class TensorSSA(cutlass_arith.ArithValue): - """A class representing thread local data from CuTe Tensor in value semantic and immutable. - - :param value: Flatten vector as ir.Value holding logic data of SSA Tensor - :type value: ir.Value - :param shape: The nested shape in CuTe of the vector - :type shape: Shape - :param dtype: Data type of the tensor elements - :type dtype: Type[Numeric] - - :ivar _shape: The nested shape in CuTe of the vector - :ivar _dtype: Data type of the tensor elements - - :raises ValueError: If shape is not static - """ - - def __init__(self, value, shape: Shape, dtype: Type[Numeric]): - """Initialize a new TensorSSA object. - - :param value: Flatten vector as ir.Value holding logic data of SSA Tensor - :type value: ir.Value - :param shape: The nested shape in CuTe of the vector - :type shape: Shape - :param dtype: Data type of the tensor elements - :type dtype: Type[Numeric] - :raises ValueError: If shape is not static - """ - if not is_static(shape): - raise ValueError("dynamic shape is not supported") - - signed = dtype.signed if issubclass(dtype, Integer) else False - super().__init__(value, signed) - - self._shape = shape - self._dtype = dtype - self._layout = None - - @property - def dtype(self) -> Type[Numeric]: - return self._dtype - - @property - def element_type(self) -> Type[Numeric]: - return self._dtype - - @abstractmethod - def __extract_mlir_values__(self): - return [self] - - @abstractmethod - def __new_from_mlir_values__(self, values): - return TensorSSA(values[0], self.shape, self.dtype) - - def __str__(self): - return f"tensor_value<{self.type} o {self.shape}>" - - @property - def shape(self): - return self._shape - - @overload - def _apply_op(self, op, other: "TensorSSA", flip, *, loc, ip) -> "TensorSSA": ... - - @overload - def _apply_op( - self, op, other: cutlass_arith.ArithValue, flip, *, loc, ip - ) -> "TensorSSA": ... - - @overload - def _apply_op( - self, op, other: Union[int, float, bool], flip, *, loc, ip - ) -> "TensorSSA": ... - - def _apply_op(self, op, other, flip=False, *, loc=None, ip=None): - def get_attr_for_type(ty, value): - if isinstance(ty, ir.IntegerType): - return ir.IntegerAttr.get(ty, value) - elif isinstance(ty, ir.FloatType): - return ir.FloatAttr.get(ty, value) - else: - raise TypeError(f"unsupported type: {ty}") - - # Canonicalize into Numeric - if isinstance(other, (int, float, bool)) or ( - not isinstance(other, TensorSSA) - and isinstance(other, cutlass_arith.ArithValue) - ): - other = as_numeric(other) - - # Promote types - lhs, rhs, res_type = _binary_op_type_promote(self, other) - - # Promote scalar to vector - if not isinstance(rhs, TensorSSA): - if isinstance(rhs, Numeric): - vect_val = vector.broadcast(lhs.type, rhs.ir_value(loc=loc, ip=ip)) - else: - elem_attr = get_attr_for_type(lhs.type.element_type, rhs) - vect_attr = ir.DenseElementsAttr.get_splat(lhs.type, elem_attr) - vect_val = arith.constant(lhs.type, vect_attr, loc=loc, ip=ip) - rhs = TensorSSA(vect_val, lhs.shape, lhs.dtype) - - if flip: - lhs, rhs = rhs, lhs - - if op in ( - operator.lt, - operator.le, - operator.gt, - operator.ge, - operator.eq, - operator.ne, - ): - res_type = Boolean - - assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}" - - def _broadcast(s, t): - if s == 1: - return t - elif t == 1: - return s - elif s == t: - return s - else: - raise ValueError(f"cannot broadcast {s} and {t}") - - max_rank = max(rank(lhs.shape), rank(rhs.shape)) - lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank) - rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank) - res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape) - - # broadcast to the same shape - lhs = lhs.broadcast_to(res_shape) - rhs = rhs.broadcast_to(res_shape) - - if ( - op in (operator.add, operator.sub) - and lhs.dtype == Boolean - and rhs.dtype == Boolean - ): - res = op(lhs.to(Int32), rhs.to(Int32)) - zero = zeros_like(res) - res = res.__ne__(zero).to(res_type) - else: - lhs_val = lhs.maybe_downcast() - rhs_val = rhs.maybe_downcast() - - if issubclass(lhs.dtype, Integer): - lhs_val = lhs_val.with_signedness(lhs.dtype.signed) - - if issubclass(rhs.dtype, Integer): - rhs_val = rhs_val.with_signedness(rhs.dtype.signed) - - res_vect = op(lhs_val, rhs_val) - res = TensorSSA(res_vect, lhs._shape, res_type) - - return res - - def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": - """ - Broadcast the tensor to the target shape. - """ - # pad source shape to the same rank - shape = append(self.shape, 1, up_to_rank=rank(target_shape)) - if shape == target_shape: - return self - - def _check_broadcast(s, t): - if s != t and s != 1: - raise ValueError( - f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" - ) - - transform_leaf(_check_broadcast, shape, target_shape) - - # reshape to flatten N-D vector - flat_shp = flatten_to_tuple(shape) - temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type) - temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) - - # broadcast to result N-D vector - flat_tgt_shp = flatten_to_tuple(target_shape) - temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type) - temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip) - - res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore - res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip) - - return TensorSSA(res_1d_vect, target_shape, self.dtype) - - def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the results of tensor^other. - - :param other: The other tensor for exponent. - :type other: TensorSSA - :return: The power of the tensor. - :rtype: TensorSSA - """ - return self._apply_op(operator.pow, other, loc=loc, ip=ip) - - def __rpow__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the results of other^tensor. - - :param other: The other tensor to compute power with. - :type other: TensorSSA - :return: The element-wise power of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.pow, other, flip=True, loc=loc, ip=ip) - - def __add__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the sum of the tensor and another tensor. - - :param other: The other tensor to add. - :type other: TensorSSA - :return: The sum of the two tensors with the same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.add, other, loc=loc, ip=ip) - - def __radd__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the sum of the tensor and another tensor (reverse add) - - :param other: The other tensor to add. - :type other: TensorSSA - :return: The sum of the two tensors with the same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.add, other, flip=True, loc=loc, ip=ip) - - def __sub__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the difference of the tensor and another tensor. - - :param other: The other tensor to subtract. - :type other: TensorSSA - :return: The subtraction of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.sub, other, loc=loc, ip=ip) - - def __rsub__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the difference of the tensor and another tensor (reverse subtract) - - :param other: The other tensor to subtract. - :type other: TensorSSA - :return: The subtraction of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.sub, other, flip=True, loc=loc, ip=ip) - - def __mul__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the multiplication of the tensor and another tensor. - - :param other: The other tensor to multiply. - :type other: TensorSSA - :return: The multiplication of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mul, other, loc=loc, ip=ip) - - def __rmul__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the multiplication of the tensor and another tensor (reverse multiply) - - :param other: The other tensor to multiply. - :type other: TensorSSA - :return: The multiplication of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mul, other, flip=True, loc=loc, ip=ip) - - def __mod__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the modulo of the tensor and another tensor. - - :param other: The other tensor to compute modulo with. - :type other: TensorSSA - :return: The element-wise modulo of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mod, other, loc=loc, ip=ip) - - def __rmod__(self, other) -> "TensorSSA": - """ - Returns the modulo of the tensor and another tensor (reverse modulo) - - :param other: The other tensor to compute modulo with. - :type other: TensorSSA - :return: The element-wise modulo of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mod, other, flip=True) - - def __floordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the floordiv(//) of the tensor and another tensor. - - :param other: The other tensor to compute floordiv with. - :type other: TensorSSA - :return: The floordiv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.floordiv, other, loc=loc, ip=ip) - - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the floordiv(//) of the tensor and another tensor (reverse floordiv) - - :param other: The other tensor to compute floordiv with. - :type other: TensorSSA - :return: The floordiv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.floordiv, other, flip=True, loc=loc, ip=ip) - - def __truediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the truediv(/) of the tensor and another tensor. - - :param other: The other tensor to compute truediv with. - :type other: TensorSSA - :return: The truediv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.truediv, other, loc=loc, ip=ip) - - def __rtruediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the truediv(/) of the tensor and another tensor (reverse truediv) - - :param other: The other tensor to compute truediv with. - :type other: TensorSSA - :return: The truediv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.truediv, other, flip=True, loc=loc, ip=ip) - - def __eq__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the comparison of the tensor and another tensor as mask - - :param other: The other tensor to compare. - :type other: TensorSSA - :return: The comparison of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.eq, other, loc=loc, ip=ip) - - def __ne__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise not equal comparison of the tensor and another tensor. - - :param other: The other tensor to compare. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self != other. - :rtype: TensorSSA - """ - return self._apply_op(operator.ne, other, loc=loc, ip=ip) - - def __lt__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise less than comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self < other. - :rtype: TensorSSA - """ - return self._apply_op(operator.lt, other, loc=loc, ip=ip) - - def __le__(self, other) -> "TensorSSA": - """ - Returns the element-wise less than or equal comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self <= other. - :rtype: TensorSSA - """ - return self._apply_op(operator.le, other) - - def __gt__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise greater than comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self > other. - :rtype: TensorSSA - """ - return self._apply_op(operator.gt, other) - - def __ge__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise greater than or equal comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self >= other. - :rtype: TensorSSA - """ - return self._apply_op(operator.ge, other, loc=loc, ip=ip) - - def __xor__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise XOR of the tensor and another tensor. - - :param other: The other tensor to perform XOR with. - :type other: TensorSSA - :return: The element-wise XOR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.xor, other) - - def __rxor__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the bitwise XOR of the tensor and another tensor. - - :param other: The other tensor to compute XOR with. - :type other: TensorSSA - :return: The element-wise bitwise XOR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.xor, other, flip=True, loc=loc, ip=ip) - - def __or__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise OR of the tensor and another tensor. - - :param other: The other tensor to perform OR with. - :type other: TensorSSA - :return: The element-wise OR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.or_, other) - - def __ror__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise OR of the tensor and another tensor. - - :param other: The other tensor to perform OR with. - :type other: TensorSSA - :return: The element-wise OR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.or_, other, flip=True) - - def __and__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise AND of the tensor and another tensor. - - :param other: The other tensor to perform AND with. - :type other: TensorSSA - :return: The element-wise AND of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.and_, other) - - def __rand__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise AND of the tensor and another tensor. - - :param other: The other tensor to perform AND with. - :type other: TensorSSA - :return: The element-wise AND of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) - - def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the negation of the tensor. - - :return: The element-wise negation of the tensor - :rtype: TensorSSA - """ - - return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) - - def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): - # Coalesce and flatten source layout at terminal of coordinate - # (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...) - crd_shp = product_like(self._shape, target_profile=crd, loc=loc, ip=ip) - - # Flatten coordinate - flat_shp = flatten(crd_shp) - assert isinstance(flat_shp, tuple) and is_static(flat_shp) - # (C_0,(C_1,...), ...) -> (C_0,C_1,C_2,...) - flat_crd = flatten(crd) - - assert isinstance(flat_crd, tuple) and is_static(flat_crd) - return flat_shp, flat_crd - - def _build_result(self, res_vect, res_shp, *, loc=None, ip=None): - if isinstance(res_shp, ir.Value): - raise ValueError( - f"expects static shape and coordinates, but got {self._shape} and {crd}" - ) - - # cast back to 1D vector - res_1d_ty = ir.VectorType.get([size(res_shp)], self.type.element_type) - res_1d_vect = vector.shape_cast(res_1d_ty, res_vect, loc=loc, ip=ip) - return TensorSSA(res_1d_vect, res_shp, self.dtype) - - @dsl_user_op - def __getitem__( - self, crd: Coord, *, loc=None, ip=None - ) -> Union["TensorSSA", Numeric]: - """Access or slice tensor elements using coordinates. - - This method implements tensor evaluation T(c) = *(E + L(c)) where E is the iterator/engine - and L is the layout. It supports both direct element access and slicing operations. - - :param crd: Coordinate or slice specification for accessing tensor elements - :type crd: Coord - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Tensor element value or sliced subtensor - :rtype: Union[TensorSSA, Numeric] - - :raises ValueError: If coordinate access is invalid for the tensor layout - - **Examples:** - - .. code-block:: python - - # Create a fragment from rmem as shape (8, 4) - layout = make_layout((8, 4)) - tensor = make_fragment(layout, Float32) - frg = tensor.load() - - # Direct element access - val = frg[0] # Returns first element of fragment - val = frg[(0, 1)] # Returns element at (0, 1) - - # Slice access - sliced = frg[(3, None)] # Returns fragment slice - """ - # short-cut to no-op - if crd is None: - return self - - if not has_underscore(crd): - if self._layout is None: - self._layout = make_layout(self._shape, loc=loc, ip=ip) - idx = crd2idx(crd, self._layout, loc=loc, ip=ip) - idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip) - res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip) - return self.dtype(res_val) - - if not is_static(crd): - raise ValueError("dynamic coordinate is not supported") - - flat_shp, flat_crd = self._flatten_shape_and_coord(crd) - - multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) - # vector -> vector - tmp_vect = vector.shape_cast(multi_dim_ty, self) - - # Slice and keep dims matching `_` or None - res_shp = slice_(self._shape, crd) - if isinstance(res_shp, ir.Value): - raise TypeError( - f"expects static shape and coordinates, but got {self._shape} and {crd}" - ) - - # Offsets is index of coordinates if NOT `_` otherwise 0 - offsets = [c if c is not None else 0 for c in flat_crd] - # Sizes is size of shapes if `_` otherwise 1 - sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] - # Logic stride to index vector. Only support stride-1 by vector - strides = [1] * rank(flat_shp) - - # Vector slice on N-D vector - res_ty = ir.VectorType.get(list(sizes), self.type.element_type) - res_vect = vector.extract_strided_slice( - res_ty, tmp_vect, offsets=offsets, sizes=sizes, strides=strides - ) - - # Slice and keep dims matching `_` or None - res_shp = slice_(self._shape, crd) - return self._build_result(res_vect, res_shp, loc=loc, ip=ip) - - @dsl_user_op - def to(self, dtype: Type[Numeric], *, loc=None, ip=None): - """Convert the tensor to a different numeric type. - - :param dtype: The target numeric type to cast to. - :type dtype: Type[Numeric] - :return: A new tensor with the same shape but with elements cast to the target type. - :rtype: TensorSSA - :raises TypeError: If dtype is not a subclass of Numeric. - :raises NotImplementedError: If dtype is an unsigned integer type. - """ - if dtype is ir.Value: - return self - - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {type(dtype)}") - - src_dtype = self.dtype - if src_dtype == dtype: - return self - - # maybe downcast can lose signedness - src = self.maybe_downcast().with_signedness(self.signed) - if src_dtype.is_float and dtype.is_float: - res_vect = cutlass_arith.cvtf(src, dtype.mlir_type, loc=loc, ip=ip) - elif src_dtype.is_float and issubclass(dtype, Integer): - res_vect = cutlass_arith.fptoi( - src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip - ) - elif issubclass(src_dtype, Integer) and dtype.is_float: - res_vect = cutlass_arith.itofp( - src, src_dtype.signed, dtype.mlir_type, loc=loc, ip=ip - ) - else: - res_vect = cutlass_arith.int_to_int(src, dtype, loc=loc, ip=ip) - - return TensorSSA(res_vect, self._shape, dtype) - - def ir_value(self, *, loc=None, ip=None): - return self - - def ir_value_int8(self, *, loc=None, ip=None): - """ - Returns int8 ir value of Boolean tensor. - When we need to store Boolean tensor ssa, use ir_value_int8(). - - :param loc: Source location information, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: Optional[InsertionPoint], optional - :return: The int8 value of this Boolean - :rtype: ir.Value - """ - assert ( - self.element_type is Boolean - ), f"Only boolean type needs to be converted to int8, got {self.element_type}" - - if not hasattr(self, "_value_int8"): - self._value_int8 = arith.extsi( - T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip - ) - return self._value_int8 - - def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None): - """ - Perform reduce on selected modes with given predefined reduction op. - - :param op: The reduction operator to use (operator.add or operator.mul) - :type op: operator - :param init_val: The initial value for the reduction - :type init_val: numeric - :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept. - :type reduction_profile: Coord - - :return: The reduced tensor - :rtype: TensorSSA - - **Examples:** - - .. code-block:: python - - reduce(f32 o (4,)) - => f32 - - reduce(f32 o (4, 5)) - => f32 - reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1)) - => f32 o (4,) - reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1))) - => f32 o (4, (5,)) - """ - # short-cut to no-op - if reduction_profile is None: - return self - - if not is_weakly_congruent(reduction_profile, self.shape): - raise ValueError( - f"Expect reduction_profile be weakly congruent to the shape of the tensor, " - f"but got {reduction_profile} and {self.shape}" - ) - - if op is ReductionOp.ADD: - red_kind = vector.CombiningKind.ADD - elif op is ReductionOp.MUL: - red_kind = vector.CombiningKind.MUL - elif op is ReductionOp.MAX: - red_kind = vector.CombiningKind.MAXIMUMF - elif op is ReductionOp.MIN: - red_kind = vector.CombiningKind.MINIMUMF - else: - raise NotImplementedError( - f"{op} is not supported, expects one of " - f"{ReductionOp.ADD, ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN}" - ) - - elem_ty = self.element_type - # Canonicalize to `Numeric` and convert into MLIR value - init_val = as_numeric(init_val).ir_value(loc=loc, ip=ip) - - if depth(reduction_profile) == 0: - return vector.reduction( - elem_ty.mlir_type, red_kind, self, acc=init_val, loc=loc, ip=ip - ) - - flat_shp, flat_prof = self._flatten_shape_and_coord( - reduction_profile, loc=loc, ip=ip - ) - assert depth(flat_shp) == 1 and depth(flat_prof) == 1 - assert rank(flat_shp) == rank(flat_prof) - - temp_ty = ir.VectorType.get(list(flat_shp), elem_ty.mlir_type) - temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) - - if isinstance(flat_prof, tuple): - red_dims = [i for i, x in enumerate(flat_prof) if x is not None] - else: - red_dims = [0] - - temp_acc_shp = slice_(flat_shp, flat_prof, loc=loc, ip=ip) - temp_acc_ty = ir.VectorType.get(list(temp_acc_shp), elem_ty.mlir_type) - - init_val = vector.broadcast(temp_acc_ty, init_val, loc=loc, ip=ip) - res_vect = vector.multi_reduction( - red_kind, temp_vect, acc=init_val, reduction_dims=red_dims, loc=loc, ip=ip - ) - - # Slice and keep dims matching `_` or None - res_shp = slice_(self.shape, reduction_profile, loc=loc, ip=ip) - return self._build_result(res_vect, res_shp, loc=loc, ip=ip) - - -@dsl_user_op -def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA: - """ - Return a new TensorSSA of given shape and type, filled with fill_value. - - :param shape: Shape of the new tensor. - :type shape: tuple - :param fill_value: Value to fill the tensor with. - :type fill_value: scalar - :param dtype: Data type of the tensor. - :type dtype: Type[Numeric] - :return: Tensor of fill_value with the specified shape and dtype. - :rtype: TensorSSA - """ - size = product(shape, loc=loc, ip=ip) - if not is_static(size): - raise ValueError("shape must be static") - - if isinstance(fill_value, (ir.Value, int, float, bool)): - fill_value = dtype(fill_value) - elif isinstance(fill_value, Numeric): - fill_value = fill_value.to(dtype, loc=loc, ip=ip) - else: - raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}") - - res_ty = T.vector(size, dtype.mlir_type) - res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - return TensorSSA(res_val, shape, dtype) - - -def full_like( - a: Union[TensorSSA, Tensor], - fill_value, - dtype: Union[None, Type[Numeric]] = None, - *, - loc=None, - ip=None, -) -> TensorSSA: - """ - Return a full TensorSSA with the same shape and type as a given array. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: array_like - :param fill_value: Fill value. - :type fill_value: array_like - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Union[None, Type[Numeric]], optional - :return: Tensor of `fill_value` with the same shape and type as `a`. - :rtype: TensorSSA - - .. seealso:: - :func:`empty_like`: Return an empty array with shape and type of input. - :func:`ones_like`: Return an array of ones with shape and type of input. - :func:`zeros_like`: Return an array of zeros with shape and type of input. - :func:`full`: Return a new array of given shape filled with value. - - **Examples:** - - .. code-block:: python - - frg = cute.make_fragment(Float32, (2, 3)) - a = frg.load() - b = cute.full_like(a, 1.0) - """ - if not hasattr(a, "shape"): - raise TypeError(f"Expect `a` be shaped type, but got {type(a)}") - - return full( - a.shape, fill_value, dtype if dtype is not None else a.dtype, loc=loc, ip=ip - ) - - -def empty_like(a, dtype=None): - """ - Return a new TensorSSA with the same shape and type as a given array, without initializing entries. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: TensorSSA - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Type[Numeric], optional - :return: Uninitialized tensor with the same shape and type (unless overridden) as `a`. - :rtype: TensorSSA - """ - return full_like(a, 0, dtype) - - -def ones_like(a, dtype=None): - """ - Return a TensorSSA of ones with the same shape and type as a given array. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: TensorSSA - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Type[Numeric], optional - :return: Tensor of ones with the same shape and type (unless overridden) as `a`. - :rtype: TensorSSA - """ - return full_like(a, 1, dtype) - - -def zeros_like(a, dtype=None, *, loc=None, ip=None): - """ - Return a TensorSSA of zeros with the same shape and type as a given array. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: TensorSSA - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Type[Numeric], optional - :return: Tensor of zeros with the same shape and type (unless overridden) as `a`. - :rtype: TensorSSA - """ - return full_like(a, 0, dtype, loc=loc, ip=ip) - - -def where( - cond: TensorSSA, x: TensorSSA, y: TensorSSA, *, loc=None, ip=None -) -> TensorSSA: - """ - Return elements chosen from x or y depending on condition. - - :param cond: Where True, yield x, where False, yield y. - :type cond: TensorSSA - :param x: Values from which to choose when condition is True. - :type x: TensorSSA - :param y: Values from which to choose when condition is False. - :type y: TensorSSA - :return: A tensor with elements from x where condition is True, and elements from y where condition is False. - :rtype: TensorSSA - """ - if x.dtype != y.dtype: - raise ValueError( - f"x and y must have the same dtype, but got {x.dtype} and {y.dtype}" - ) - - if cond.dtype != Boolean: - raise ValueError(f"cond must be Boolean type, but got {cond.dtype}") - - return TensorSSA( - arith.select(cond.ir_value(), x, y, loc=loc, ip=ip), x.shape, x.dtype - ) - - -def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: - """ - Test whether any tensor element evaluates to True. - - :param x: Input tensor. - :type x: TensorSSA - :return: Returns a TensorSSA scalar containing True if any element of x is True, False otherwise. - :rtype: TensorSSA - """ - is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) - return Boolean( - vector.reduction(T.bool(), vector.CombiningKind.OR, is_true, loc=loc, ip=ip) - ) - - -def all_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: - """ - Test whether all tensor elements evaluate to True. - - :param x: Input tensor. - :type x: TensorSSA - :return: Returns a TensorSSA scalar containing True if all elements of x are True, False otherwise. - :rtype: TensorSSA - """ - is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) - return Boolean( - vector.reduction(T.bool(), vector.CombiningKind.AND, is_true, loc=loc, ip=ip) - ) - - -############################################################################## -# User defined struct -############################################################################## - - -class struct: - """ - Decorator to abstract C structure in Python DSL. - - **Usage:** - - .. code-block:: python - - # Supports base_dsl scalar int/float elements, array and nested struct: - @cute.struct - class complex: - real : cutlass.Float32 - imag : cutlass.Float32 - - - @cute.struct - class StorageA: - mbarA : cute.struct.MemRange[cutlass.Int64, stage] - compA : complex - intA : cutlass.Int16 - - - # Supports aligment for its elements: - @cute.struct - class StorageB: - a: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, size_a], 1024 - ] - b: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, size_b], 1024 - ] - x: cute.struct.Align[cutlass.Int32, 16] - compA: cute.struct.Align[complex, 16] - - - # Statically get size and alignment: - size = StorageB.__sizeof__() - align = StorageB.__alignof__() - - # Allocate and referencing elements: - storage = allocator.allocate(StorageB) - - storage.a[0] ... - storage.x ... - storage.compA.real ... - - :param cls: The struct class with annotations. - :return: The decorated struct class. - """ - - # inner class for defining a continuous memory region - class _MemRangeMeta(type): - """ - A metaclass for creating MemRange classes. - - This metaclass is used to dynamically create MemRange classes with specific - data types and sizes. - - :ivar _dtype: The data type of the MemRange. - :ivar _size: The size of the MemRange. - """ - - _dtype = None - _size = None - - def __new__(cls, name, bases, dct): - new_cls = super().__new__(cls, name, bases, dct) - return new_cls - - def __getitem__(cls, params) -> Type["struct.MemRange"]: - # get params from syntax: struct.MemRange[dtype, size] - if len(params) == 2: - dtype, size = params - else: - raise TypeError("Invalid struct.MemRange Arguments") - - if not struct._is_scalar_type(dtype): - raise TypeError("MemRange only support dsl scalar type!") - - # Create new class with proper name and parameters - new_cls = type( - f"struct.MemRange[{dtype.__name__}, {size}]", - (struct.MemRange,), - {"_dtype": dtype, "_size": size}, - ) - return new_cls - - @property - def size(cls): - return cls._size - - @property - def elem_width(cls): - return cls._dtype.width - - @property - def size_in_bytes(cls): - return cls.size * cls.elem_width // 8 - - class MemRange(metaclass=_MemRangeMeta): - """ - Defines a range of memory by `MemRange[T, size]`. - """ - - pass - - class _MemRangeData: - """ - Represents a range of memory. - - :param dtype: The data type. - :param size: The size of the memory range in bytes. - :param base: The base address of the memory range. - """ - - def __init__(self, dtype, size, base): - """ - Initializes a new memory range. - - :param dtype: The data type. - :param size: Size of the memory range in bytes. A size of **0** is accepted, but in that - case the range can only be used for its address (e.g. as a partition marker). - :param base: The base address of the memory range. - """ - self._dtype = dtype - self._size = size - self._base = base - - def data_ptr(self): - """ - Returns start pointer to the data in this memory range. - - :return: A pointer to the start of the memory range. - :raises AssertionError: If the size of the memory range is negative. - """ - assert self._size >= 0 - return recast_ptr(self._base, dtype=self._dtype) - - def get_tensor(self, layout, swizzle=None, dtype=None): - """ - Creates a tensor from the memory range. - - :param layout: The layout of the tensor. - :param swizzle: Optional swizzle pattern. - :param dtype: Optional data type; defaults to the memory range's data type if not specified. - :return: A tensor representing the memory range. - :raises TypeError: If the layout is incompatible with the swizzle. - :raises AssertionError: If the size of the memory range is not greater than zero. - """ - assert self._size > 0 - # make tensor - if isinstance(layout, ComposedLayout) and (swizzle is not None): - raise TypeError(f"incompatible layout with swizzle") - elem_type = self._dtype if dtype is None else dtype - ptr = recast_ptr(self._base, swizzle, dtype=elem_type) - res = make_tensor(ptr, layout) - return res - - def __getitem__(self, index: int) -> Any: - """ - Returns the element at the specified index in the memory range. - - :param index: The index of the element to retrieve. - :return: The element at the specified index. - :raises AssertionError: If the index is out of range. - """ - assert (index >= 0) and (index < self._size) - return self.data_ptr() + index - - # inner class for aligning a member type - class _AlignMeta(type): - """ - Aligns the given object by setting its alignment attribute. - - :param v: The object to align. Must be a struct, MemRange, or a scalar type. - :param align: The alignment value to set. - :raises TypeError: If the object is not a struct, MemRange, or a scalar type. - - :ivar _dtype: The data type to be aligned. - :ivar _align: The alignment of the data type. - """ - - _dtype = None - _align = None - - def __new__(cls, name, bases, dct): - return super().__new__(cls, name, bases, dct) - - def __getitem__(cls, params) -> Any: - if len(params) == 2: - dtype, align = params - assert align > 0 - else: - raise TypeError("Invalid struct.Align Arguments") - - if not struct._is_scalar_type(dtype) and not isinstance( - dtype, (struct, struct._MemRangeMeta) - ): - raise TypeError( - "align only can be applied to struct/MemRange/base_dsl scalar" - ) - - # Create new class with alignment - new_cls = type( - f"struct.Align[{dtype.__name__}, {align}]", - (struct.Align,), - {"_dtype": dtype, "_align": align}, - ) - return new_cls - - @property - def dtype(cls): - return cls._dtype - - @property - def align(cls): - return cls._align - - class Align(metaclass=_AlignMeta): - """ - Aligns the given type by `Align[T, alignment]`. - """ - - pass - - # util func for base dsl scalar types - @staticmethod - def _is_scalar_type(dtype): - """ - Checks if the given type is a scalar numeric type. - - :param dtype: The type to check. - :return: True if the type is a subclass of Numeric, False otherwise. - """ - return isinstance(dtype, type) and issubclass(dtype, Numeric) - - # calculate size and alignment - def __init__(self, cls): - """ - Initializes a new struct decorator instance. - - :param cls: The class representing the structured data type. - :raises TypeError: If the struct is empty. - """ - self._cls = cls - self.__name__ = f"struct::{cls.__name__}" - # Get the class annotations - self._annotations = cls.__annotations__ - # Create a dictionary to store the offsets - self._offsets: Dict[str, int] = {} - - # Calculate the offsets and alignment - offset = 0 - alignment = 1 - if len(self._annotations) == 0: - raise TypeError("Empty struct is not supported!") - for name, object in self._annotations.items(): - # get alignment of object - sub_align = 1 - if isinstance(object, struct._AlignMeta): - sub_align = object.align - object = object.dtype - - # switch addition order to support dynamic size - def add_offset(val): - return val + offset if isinstance(val, ir.Value) else offset + val - - # size of scalar - if struct._is_scalar_type(object): - dtype_size = max(1, object.width // 8) - sub_align = max(dtype_size, sub_align) - offset = self.align_offset(offset, sub_align) - self._offsets[name] = offset - offset = add_offset(dtype_size) - # size of array is size_in_bytes, alignment is elem_size - elif isinstance(object, struct._MemRangeMeta): - # Allow empty array as a free marker-only struct member. - # Use max(sub_align, ) because we might have in the future some - # object.elem_width less than 8, such as fp4, bit and others, - # and align_offset() does not support an alignment of 0. - sub_align = max(object.elem_width // 8, sub_align) - offset = self.align_offset(offset, sub_align) - self._offsets[name] = offset - offset = add_offset(object.size_in_bytes) - # size of struct - elif isinstance(object, struct): - sub_align = max(object.__alignof__(), sub_align) - offset = self.align_offset(offset, sub_align) - self._offsets[name] = offset - offset = add_offset(object.__sizeof__()) - else: - raise TypeError( - f"Struct element only support struct/array/base_dsl scalar, " - f"but got {object}" - ) - # Total aligment determined by the strictest requirement - alignment = max(alignment, sub_align) - # Total size determined by alignment - self._align_of = alignment - self._size_of = self.align_offset(offset, alignment) - - # create the __init__ method for decorated struct - def __call__(self, base: Any) -> None: - """ - Creates a new instance of the decorated struct. - - :param base: The base address of the struct. - :return: An instance of the decorated struct. - :raises TypeError: If the base pointer is not byte-sized. - """ - if base.type.value_type.width != 8: - raise TypeError("struct base ptr value type must be byte sized.") - # make an new object of user-defined decorated struct - # otherwise it will override same self._cls when new instance created - cls = self._cls() - setattr(cls, "_base", base) - for name, off in self._offsets.items(): - obj = self._annotations[name] - if isinstance(obj, struct._AlignMeta): - obj = obj.dtype - if struct._is_scalar_type(obj): - new_obj = recast_ptr(base + off, dtype=obj) - setattr(cls, name, new_obj) - elif isinstance(obj, struct._MemRangeMeta): - new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off) - setattr(cls, name, new_obj) - elif isinstance(obj, struct): - new_obj = obj(base + off) - setattr(cls, name, new_obj) - else: - raise TypeError( - f"Struct element only support struct/array/base_dsl scalar, " - f"but got {obj}" - ) - return cls - - # get size - def size_in_bytes(self) -> int: - """ - Returns the size of the struct in bytes. - - :return: The size of the struct. - """ - return self._size_of - - # get size - def __sizeof__(self) -> int: - return self._size_of - - # get alignment - def __alignof__(self) -> int: - return self._align_of - - # util func for aligning offset - @staticmethod - def align_offset(offset, align): - """ - Return the round-up offset up to the next multiple of align. - """ - assert align > 0 and not ( - align & (align - 1) - ), "align should be a strictly positive power of 2." - return (offset + (align - 1)) & ~(align - 1) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py deleted file mode 100644 index daaa608262d00268ec1c47dfe32758c555f009b0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py +++ /dev/null @@ -1,445 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .core import TensorSSA -from .typing import Numeric -from cutlass._mlir.dialects import math, arith - -from typing import Callable, Union - - -def _math_op(func: Callable, fastmath: bool, *args, **kwargs): - """Dispatch the function to either a TensorSSA or a Numeric(Float). - - :param func: The function to dispatch - :param args: The input tensor or scalar - :param kwargs: The input tensor or scalar - """ - arg_type = type(args[0]) - for arg in args: - if not isinstance(arg, TensorSSA) and ( - not isinstance(arg, Numeric) or not type(arg).is_float - ): - raise TypeError( - f"Expected a TensorSSA or Numeric(Float), but got {type(arg)}" - ) - if not isinstance(arg, arg_type): - raise TypeError( - f"Expected all inputs to be of type {arg_type}, but got {type(arg)}" - ) - - fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none - if isinstance(args[0], TensorSSA): - return TensorSSA( - func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype - ) - else: - args = [a.ir_value() for a in args] - return func(*args, fastmath=fastmath_flag) - - -def acos( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise arc cosine of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the arc cosine of each element in input tensor - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = acos(y) # Compute arc cosine - """ - return _math_op(math.acos, fastmath, a) - - -def asin( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise arc sine of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the arc sine of each element in input tensor - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = asin(y) # Compute arc sine - """ - return _math_op(math.asin, fastmath, a) - - -def atan( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise arc tangent of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the arc tangent of each element in input tensor - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = atan(y) # Compute arc tangent - """ - raise NotImplementedError("atan is not implemented") - return _math_op(math.atan, fastmath, a) - - -def atan2( - a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise arc tangent of two tensors. - - Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians - between the positive x-axis and the point given by the coordinates (b, a). - - :param a: First input tensor (y-coordinates) - :type a: Union[TensorSSA, Numeric] - :param b: Second input tensor (x-coordinates) - :type b: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the arc tangent of a/b element-wise - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - y = cute.make_fragment(ptr1, layout).load() # y coordinates - x = cute.make_fragment(ptr2, layout).load() # x coordinates - theta = atan2(y, x) # Compute angles - """ - return _math_op(math.atan2, fastmath, a, b) - - -def cos( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise cosine of the input tensor. - - :param a: Input tensor (in radians) - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the cosine of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = cos(y) # Compute cosine - """ - return _math_op(math.cos, fastmath, a) - - -def erf( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise error function of the input tensor. - - The error function is defined as: - erf(x) = 2/√π ∫[0 to x] exp(-t²) dt - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the error function value for each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = erf(y) # Compute error function - """ - return _math_op(math.erf, fastmath, a) - - -def exp( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise exponential of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the exponential of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = exp(y) # Compute exponential - """ - return _math_op(math.exp, fastmath, a) - - -def exp2( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise base-2 exponential of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing 2 raised to the power of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = exp2(y) # Compute 2^x - """ - return _math_op(math.exp2, fastmath, a) - - -def log( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise natural logarithm of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the natural logarithm of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = log(y) # Compute natural logarithm - """ - return _math_op(math.log, fastmath, a) - - -def log2( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise base-2 logarithm of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the base-2 logarithm of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = log2(y) # Compute log base 2 - """ - return _math_op(math.log2, fastmath, a) - - -def log10( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise base-10 logarithm of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the base-10 logarithm of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = log10(y) # Compute log base 10 - """ - return _math_op(math.log10, fastmath, a) - - -def rsqrt( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise reciprocal square root of the input tensor. - - Computes 1/√x element-wise. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the reciprocal square root of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = rsqrt(y) # Compute 1/√x - """ - return _math_op(math.rsqrt, fastmath, a) - - -def sin( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise sine of the input tensor. - - :param a: Input tensor (in radians) - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the sine of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = sin(y) # Compute sine - """ - return _math_op(math.sin, fastmath, a) - - -def sqrt( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise square root of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the square root of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = sqrt(y) # Compute square root - """ - return _math_op(math.sqrt, fastmath, a) - - -def tan( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise tangent of the input tensor. - - :param a: Input tensor (in radians) - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the tangent of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = tan(y) # Compute tangent - """ - return _math_op(math.tan, fastmath, a) - - -def tanh( - a: Union[TensorSSA, Numeric], fastmath: bool = False -) -> Union[TensorSSA, Numeric]: - """Compute element-wise hyperbolic tangent of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :return: Tensor containing the hyperbolic tangent of each element - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_fragment(layout) # Create tensor - y = x.load() # Load values - z = tanh(y) # Compute hyperbolic tangent - """ - return _math_op(math.tanh, fastmath, a) - - -__all__ = [ - "acos", - "asin", - "atan", - "atan2", - "cos", - "erf", - "exp", - "exp2", - "log", - "log10", - "log2", - "rsqrt", - "sin", - "sqrt", - "tan", - "tanh", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py deleted file mode 100644 index 0655bb09c05ae84714656020127cb41a4f28fbf6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from . import warp -from . import cpasync -from . import warpgroup -from . import tcgen05 - -from .common import * -from .helpers import * - - -# __all__ is required here for documentation generation -__all__ = [ - "OpError", - "MmaUniversalOp", - "CopyUniversalOp", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py deleted file mode 100644 index 1b0c4c82debcd55cd7f3d7df0e21920cda83ca18..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. -import enum -from dataclasses import dataclass -from typing import Type, Optional - -from cutlass.cutlass_dsl import DSLBaseError - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from .. import core -from ..typing import Float16, Float32, Float64, Numeric - - -class OpError(DSLBaseError): - """ - An exception class for Op construction errors. - """ - - def __init__( - self, op: core.Op, message: str, suggestion: Optional[str] = None - ) -> None: - if suggestion is None: - # Default suggestion - suggestion = "Check your Op construction code" - super().__init__( - message, - error_code=f"{op.__class__.__name__} error", - suggestion=suggestion, - ) - - -#################################################################################################### -# -# MMA Ops and Traits -# -#################################################################################################### - - -@dataclass(frozen=True) -class MmaUniversalOp(core.MmaOp): - """ - The universal MMA Operation. - - This Operation currently expects the A/B operands as well as the accumulator to share the same - data types. - - :param abacc_dtype: The data type for the A/B operands and the accumulator - :type abacc_dtype: Type[Numeric] - """ - - abacc_dtype: Type[Numeric] - - def __post_init__(self) -> None: - if self.abacc_dtype not in [Float16, Float32, Float64]: - raise OpError( - self, - f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64", - ) - - def __str__(self) -> str: - return ( - "universal MMA Operation using FMA" - f"\n A/B/Accumulator data type = {self.abacc_dtype}" - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait": - shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">') - atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get( - shape_mnk_attr, - self.abacc_dtype.mlir_type, - self.abacc_dtype.mlir_type, - self.abacc_dtype.mlir_type, - ) - return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip)) - - def _verify_fragment_A(self, input, *, loc=None, ip=None): - pass - - def _verify_fragment_B(self, input, *, loc=None, ip=None): - pass - -class MmaUniversalTrait(core.Trait): - pass - - -#################################################################################################### -# -# Copy Ops and Traits -# -#################################################################################################### - - -class MemoryOrder(enum.Enum): - WEAK = _cute_ir.MemOrderKind.WEAK - RELAXED = _cute_ir.MemOrderKind.RELAXED - ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE - RELEASE = _cute_ir.MemOrderKind.RELEASE - ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL - SC = _cute_ir.MemOrderKind.SC - MMIO = _cute_ir.MemOrderKind.MMIO - CONSTANT = _cute_ir.MemOrderKind.CONSTANT - VOLATILE = _cute_ir.MemOrderKind.VOLATILE - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir(self) -> _cute_ir.MemOrderKind: - return self.value - - -class MemoryScope(enum.Enum): - CTA = _cute_ir.MemScopeKind.CTA - CLUSTER = _cute_ir.MemScopeKind.CLUSTER - GPU = _cute_ir.MemScopeKind.GPU - SYS = _cute_ir.MemScopeKind.SYS - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir(self) -> _cute_ir.MemScopeKind: - return self.value - -@dataclass(frozen=True) -class CopyUniversalOp(core.CopyOp): - """ - The universal Copy Operation. - - When creating a Copy Atom out of this operation, the expected usage pattern is - - .. code-block:: python - - op = cute.nvgpu.CopyUniversalOp() - atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) - - - ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \ - or the destination TV Layout) in unit of tensor elements and is used for partitioning by \ - ``TiledCopy`` for example - - ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \ - execution. This can be larger than the width of the above data type. When not provided, \ - the compiler will do a best effort at auto-vectorizing. - """ - - def __str__(self) -> str: - return "universal Copy Operation" - - def _make_trait( - self, - copy_internal_type: Type[Numeric], - *, - loc=None, - ip=None, - **kwargs, - ) -> "CopyUniversalTrait": - num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) - memory_order = kwargs.get("memory_order", MemoryOrder.WEAK) - memory_scope = kwargs.get("memory_scope", MemoryScope.CTA) - if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): - raise ValueError( - "expects a 'num_bits_per_copy' kw argument of type int that is non-negative " - f"when creating a copy Atom for {self.__class__.__name__}" - ) - ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - copy_internal_type.mlir_type, - num_bits_per_copy, - memory_order._to_ir(), - memory_scope._to_ir(), - ) - return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class CopyUniversalTrait(core.Trait): - pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py deleted file mode 100644 index 246360c2eb43ed5c4ca45127c579bc9f496caa08..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .copy import * -from .helpers import * - - -# __all__ is required here for documentation generation -__all__ = [ - # - # copy.py - # - "LoadCacheMode", - "CopyG2SOp", - "CopyBulkTensorTileG2SOp", - "CopyBulkTensorTileG2SMulticastOp", - "CopyBulkTensorTileS2GOp", - "CopyReduceBulkTensorTileS2GOp", - # - # helpers.py - # - "make_tiled_tma_atom", - "tma_partition", - "create_tma_multicast_mask", - "prefetch_descriptor", - "copy_tensormap", - "update_tma_descriptor", - "fence_tma_desc_acquire", - "cp_fence_tma_desc_release", - "fence_tma_desc_release", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py deleted file mode 100644 index a15495602304700d19803825d93004e0fa9fc509..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +++ /dev/null @@ -1,471 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from dataclasses import dataclass -from typing import Optional, Type - -from cutlass.cutlass_dsl import CuTeDSL, t - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from ...core import CopyOp, Trait, ReductionOp -from ...typing import Int16, Pointer, Integer, Numeric -from ..common import OpError -from ..tcgen05.mma import CtaGroup - - -#################################################################################################### -# -# Aynchronous copies -# -#################################################################################################### - - -class LoadCacheMode(enum.Enum): - """ - An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction. - - See the `PTX documentation `__. - """ - - ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always - GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_ - STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming - LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use - NONE = _cute_nvgpu_ir.LoadCacheMode.none - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode: - return self.value - - -@dataclass(frozen=True) -class CopyG2SOp(CopyOp): - """ - Non-bulk asynchronous GMEM to SMEM Copy Operation. - - See the `PTX documentation `__. - """ - - cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS - - def __str__(self) -> str: - res = "cp.async GMEM -> SMEM copy Operation" - if self.cache_mode != LoadCacheMode.ALWAYS: - res += f"\n with cache mode = {self.cache_mode}" - return res - - def _make_trait( - self, - copy_internal_type: Type[t.Numeric], - *, - loc=None, - ip=None, - **kwargs, - ) -> "CopyG2STrait": - num_bits_per_copy = kwargs.get("num_bits_per_copy", None) - # Verify that the user provided enum values - if not isinstance(self.cache_mode, LoadCacheMode): - raise OpError( - self, - "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", - ) - if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0): - raise ValueError( - "expects a 'num_bits_per_copy' kw argument of type int that is positive " - f"when creating a copy Atom for {self.__class__.__name__}" - ) - # Verify that the user provided enum values - if not isinstance(self.cache_mode, LoadCacheMode): - raise OpError( - self, - "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", - ) - ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get( - copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy - ) - return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class CopyG2STrait(Trait): - pass - - -#################################################################################################### -# -# Bulk tensor copies a.k.a TMA copies -# -#################################################################################################### - -TMA_MBAR_PTR_FIELD_NAME = "tma_bar" -TMA_MASK_FIELD_NAME = "mcast_mask" -TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr" - -# -# TMA GMEM -> SMEM copies -# - - -@dataclass(frozen=True) -class CopyBulkTensorTileG2SOp(CopyOp): - """ - Bulk tensor asynchrnous GMEM to SMEM Copy Operation using the TMA unit. - - See the `PTX documentation `__. - This Operation uses TMA in the ``.tile`` mode. - """ - - cta_group: CtaGroup = CtaGroup.ONE - - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - - def __post_init__(self) -> None: - if not isinstance(self.cta_group, CtaGroup): - raise OpError( - self, "expects the 'cta_group' parameter to be a CtaGroup instance" - ) - # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": - raise OpError( - self, - f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - - def __str__(self) -> str: - res = "cp.async GMEM -> SMEM bulk tensor copy Operation" - if self.cta_group == CtaGroup.TWO: - res += f"\n CTA group = 2" - return res - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "CopyBulkTensorTileG2SNonExecTrait": - raise NotImplementedError( - "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" - ) - - def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: - if self.cta_group == CtaGroup.ONE: - return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90 - elif self.cta_group == CtaGroup.TWO: - return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm - else: - assert False, "unrecognized self.cta_group" - - -class CopyBulkTensorTileG2SNonExecTrait(Trait): - # We allow kw args to be dropped so that the user can write common code for non-multicast - # and multicast loads. - def unpack( - self, - *, - loc=None, - ip=None, - tma_bar_ptr: Optional[Pointer] = None, - tma_desc_ptr: Optional[Pointer] = None, - **kwargs, - ): - """ - Custom implementation of unpack for non-executable TMAs. - - The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when - using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error. - """ - if not isinstance(tma_bar_ptr, Pointer): - raise ValueError( - "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" - ) - exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>" - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip - ) - if isinstance(tma_desc_ptr, Pointer): - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip - ) - return exec_value - - -# -# TMA GMEM -> SMEM multicast copies -# - - -@dataclass(frozen=True) -class CopyBulkTensorTileG2SMulticastOp(CopyOp): - """ - Bulk tensor asynchrnous multicast GMEM to SMEM Copy Operation using the TMA unit. - - See the `PTX documentation `__. - This Operation uses TMA in the ``.tile`` mode. - """ - - cta_group: CtaGroup = CtaGroup.ONE - - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - - def __post_init__(self): - if not isinstance(self.cta_group, CtaGroup): - raise OpError( - self, "expects the 'cta_group' parameter to be a CtaGroup instance" - ) - # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": - raise OpError( - self, - f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - - def __str__(self) -> str: - res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation" - if self.cta_group == CtaGroup.TWO: - res += f"\n CTA group = 2" - return res - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait": - raise NotImplementedError( - "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" - ) - - def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: - if self.cta_group == CtaGroup.ONE: - return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast - elif self.cta_group == CtaGroup.TWO: - return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast - else: - assert False, "unrecognized self.cta_group" - - -class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): - def unpack( - self, - *, - loc=None, - ip=None, - tma_bar_ptr: Optional[Pointer] = None, - mcast_mask=None, - tma_desc_ptr=None, - ): - """ - Custom implementation of unpack for non-executable TMAs. - - The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be - provided when using `cute.copy`. - """ - if not isinstance(tma_bar_ptr, Pointer): - raise ValueError( - "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" - ) - if not isinstance(mcast_mask, Integer): - raise ValueError( - "expects a multicast mask to be provided via the mcast_mask kw argument" - ) - exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload" - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip - ) - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload" - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - if isinstance(tma_desc_ptr, Pointer): - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip - ) - return exec_value - - -# -# TMA SMEM -> GMEM copies -# - - -@dataclass(frozen=True) -class CopyBulkTensorTileS2GOp(CopyOp): - """ - Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit. - - See the `PTX documentation `__. - This Operation uses TMA in the ``.tile`` mode. - """ - - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - - def __post_init__(self): - # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - - def __str__(self) -> str: - return "cp.async SMEM -> GMEM bulk tensor copy Operation" - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "CopyBulkTensorTileS2GTrait": - raise NotImplementedError( - "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" - ) - - -class CopyBulkTensorTileS2GTrait(Trait): - def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): - """ - Custom implementation of unpack for non-executable TMAs. - """ - exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) - if isinstance(tma_desc_ptr, Pointer): - attr_str = ( - f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>" - ) - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip - ) - return exec_value - -@dataclass(frozen=True) -class CopyReduceBulkTensorTileS2GOp(CopyOp): - """ - Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit. - - See the `PTX documentation `__. - This Operation uses TMA in the ``.tile`` mode. - """ - - reduction_kind: ReductionOp = ReductionOp.ADD - - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - - def __post__init__(self): - # Arch verification - arch = CuTeDSL.__get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - - def __str__(self) -> str: - return "cp.async SMEM -> GMEM bulk tensor reduction Operation" - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "CopyReduceBulkTensorTileS2GTrait": - raise NotImplementedError( - "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" - ) - - def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind: - if self.reduction_kind == ReductionOp.ADD: - return _cute_nvgpu_ir.ReductionKind.ADD - elif self.reduction_kind == ReductionOp.MIN: - return _cute_nvgpu_ir.ReductionKind.MIN - elif self.reduction_kind == ReductionOp.MAX: - return _cute_nvgpu_ir.ReductionKind.MAX - elif self.reduction_kind == ReductionOp.INC: - return _cute_nvgpu_ir.ReductionKind.INC - elif self.reduction_kind == ReductionOp.DEC: - return _cute_nvgpu_ir.ReductionKind.DEC - elif self.reduction_kind == ReductionOp.AND: - return _cute_nvgpu_ir.ReductionKind.AND - elif self.reduction_kind == ReductionOp.OR: - return _cute_nvgpu_ir.ReductionKind.OR - elif self.reduction_kind == ReductionOp.XOR: - return _cute_nvgpu_ir.ReductionKind.XOR - else: - assert False, "unrecognized self.reduction_kind" - - -class CopyReduceBulkTensorTileS2GTrait(Trait): - def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): - """ - Custom implementation of unpack for non-executable TMAs. - """ - exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) - if isinstance(tma_desc_ptr, Pointer): - attr_str = ( - f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>" - ) - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip - ) - return exec_value - -__all__ = [ - "LoadCacheMode", - "CopyG2SOp", - "CopyBulkTensorTileG2SOp", - "CopyBulkTensorTileG2SMulticastOp", - "CopyBulkTensorTileS2GOp", - "CopyReduceBulkTensorTileS2GOp", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py deleted file mode 100644 index f64f07f167501d1805096373e915017612de4387..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ /dev/null @@ -1,341 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Optional, Tuple, Type, Union - -from cutlass.cutlass_dsl import dsl_user_op - -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects import llvm - -from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta -from ... import core -from .copy import ( - CopyBulkTensorTileG2SOp, - CopyBulkTensorTileG2SMulticastOp, - CopyBulkTensorTileS2GOp, - CopyReduceBulkTensorTileS2GOp, - CopyBulkTensorTileG2SNonExecTrait, - CopyBulkTensorTileG2SMulticastNonExecTrait, - CopyBulkTensorTileS2GTrait, - CopyReduceBulkTensorTileS2GTrait, -) - - -@dsl_user_op -def make_tiled_tma_atom( - op: Union[ - CopyBulkTensorTileG2SOp, - CopyBulkTensorTileG2SMulticastOp, - CopyBulkTensorTileS2GOp, - CopyReduceBulkTensorTileS2GOp, - ], - gmem_tensor: Tensor, - smem_layout: Union[Layout, core.ComposedLayout], - cta_tiler: Tiler, - num_multicast: int = 1, - *, - internal_type: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Tuple[core.CopyAtom, Tensor]: - """ - Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM - buffer with the given Layout. - - Given - - - a GMEM tensor - - a SMEM layout - - a CTA-level Tiler - - this function figures out the bulk tensor asynchronous copy instruction to use with the maximum - "TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided - layout and consistent with the provided Tiler. - - This function returns two results: - - 1. the Copy Atom - 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates \ - that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the \ - associated layout can output coordinates. Otherwise, TMA tensors can be partitioned \ - similarly to any other CuTe tensors using the algebra. - - :param op: The Copy Operation to construct an Atom for - :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp] - :param gmem_tensor: The GMEM tensor involved in the Copy - :type gmem_tensor: Tensor - :param smem_layout: The SMEM layout to construct the Copy Atom for - :type smem_layout: Union[Layout, core.ComposedLayout] - :param cta_tiler: The CTA Tiler to use - :type cta_tiler: Tiler - :param num_multicast: The multicast factor - :type num_multicast: int - :param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit - :type internal_type: Type[Numeric] - :return: A Copy Atom for this Operation and the associated TMA tensor - :rtype: Tuple[core.CopyAtom, Tensor] - """ - - if internal_type is not None: - if not isinstance(internal_type, NumericMeta): - raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") - internal_type = internal_type.mlir_type - - cta_v_map = core.composition( - core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip), - cta_tiler, - loc=loc, - ip=ip, - ) - - if isinstance(op, CopyBulkTensorTileG2SOp): - if num_multicast != 1: - raise ValueError( - f"expects num_multicast to be 1 for non multicast G2S copies, " - f"but got {num_multicast}" - ) - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, - smem_layout, - cta_v_map, - op._to_ir(), - num_multicast=num_multicast, - internal_type=internal_type, - loc=loc, - ip=ip, - ) - return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] - elif isinstance(op, CopyBulkTensorTileG2SMulticastOp): - if num_multicast < 1: - raise ValueError( - f"expects num_multicast to be >= 1 for multicast G2S copies, " - f"but got {num_multicast}" - ) - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, - smem_layout, - cta_v_map, - op._to_ir(), - num_multicast=num_multicast, - internal_type=internal_type, - loc=loc, - ip=ip, - ) - return ( - core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), - res[1], - ) - elif isinstance(op, CopyBulkTensorTileS2GOp): - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store( - gmem_tensor.value, - smem_layout, - cta_v_map, - internal_type=internal_type, - loc=loc, - ip=ip, - ) - return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1] - elif isinstance(op, CopyReduceBulkTensorTileS2GOp): - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce( - gmem_tensor.value, - smem_layout, - cta_v_map, - op._to_ir(), - internal_type=internal_type, - loc=loc, - ip=ip, - ) - return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1] - else: - raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}") - - -@dsl_user_op -def tma_partition( - atom: core.CopyAtom, - cta_coord: Coord, - cta_layout: Layout, - smem_tensor: Tensor, - gmem_tensor: Tensor, - *, - loc=None, - ip=None, -) -> Tuple[Tensor, Tensor]: - """ - Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom. - """ - cta_coord_val = core._pack_coord(cta_coord, loc=loc, ip=ip) - s, d = _cute_nvgpu_ir.atom_tma_partition( - atom._trait.value, - cta_coord=cta_coord_val, - cta_layout=cta_layout, - smem_tensor=smem_tensor.value, - gmem_tensor=gmem_tensor.value, - loc=loc, - ip=ip, - ) - return s, d - - -@dsl_user_op -def create_tma_multicast_mask( - cta_layout_vmnk: Layout, - cta_coord_vmnk: Coord, - mcast_mode: int, - *, - loc=None, - ip=None, -) -> Int16: - """ - Computes a multicast mask for a TMA load Copy. - - :param cta_layout_vmnk: The VMNK layout of the cluster - :type cta_layout_vmnk: Layout - :param cta_coord_vmnk: The VMNK coordinate of the current CTA - :type cta_coord_vmnk: Coord - :param mcast_mode: The tensor mode in which to multicast - :type mcast_mode: int - :return: The resulting mask - :rtype: Int16 - """ - if core.rank(cta_layout_vmnk) != 4: - raise ValueError( - f"cta_layout_vmnk must be rank 4, but got {core.pretty_str(cta_layout_vmnk)}" - ) - if core.rank(cta_coord_vmnk) != 4: - raise ValueError( - f"cta_coord_vmnk must be rank 4, but got {core.pretty_str(cta_coord_vmnk)}" - ) - return core.make_layout_image_mask( - cta_layout_vmnk, cta_coord_vmnk, mcast_mode, loc=loc, ip=ip - ) - - -@dsl_user_op -def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None: - """ - Prefetches the TMA descriptor associated with the TMA Atom. - """ - _cute_nvgpu_ir.prefetch_tma_desc(tma_atom._trait.value, loc=loc, ip=ip) - - -@dsl_user_op -def copy_tensormap( - tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None -) -> None: - """ - Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided - pointer. - - :param tma_atom: The TMA Copy Atom - :type tma_atom: CopyAtom - :param tensormap_ptr: The pointer to the memory location to copy the tensormap to - :type tensormap_ptr: Pointer - """ - _cute_nvgpu_ir.copy_tma_desc( - tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip - ) - - -@dsl_user_op -def update_tma_descriptor( - tma_atom: core.CopyAtom, - gmem_tensor: Tensor, - tma_desc_ptr: Pointer, - *, - loc=None, - ip=None, -) -> None: - """ - Updates the TMA descriptor in the memory location pointed to by the provided pointer using - information from a TMA Copy Atom and the provided GMEM tensor. - - Specifically, the following fields of the TMA descriptor will be updated: - - 1. the GMEM tensor base address - 2. the GMEM tensor shape - 3. the GMEM tensor stride - - Other fields of the TMA descriptor are left unchanged. - - :param tma_atom: The TMA Copy Atom - :type tma_atom: CopyAtom - :param gmem_tensor: The GMEM tensor - :type gmem_tensor: Tensor - :param tensormap_ptr: The pointer to the memory location of the descriptor to udpate - :type tensormap_ptr: Pointer - """ - _cute_nvgpu_ir.update_tma_desc( - tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip - ) - - -@dsl_user_op -def fence_tma_desc_acquire( - tma_desc_ptr: Pointer, - *, - loc=None, - ip=None, -) -> None: - """ - See the `PTX documentation `__. - """ - tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() - llvm.inline_asm( - None, - [tma_desc_ptr_i64], - "fence.proxy.tensormap::generic.acquire.gpu [$0], 128;", - "l", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def cp_fence_tma_desc_release( - tma_desc_global_ptr: Pointer, - tma_desc_shared_ptr: Pointer, - *, - loc=None, - ip=None, -) -> None: - """ - See the `PTX documentation `__. - """ - tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value() - tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value() - llvm.inline_asm( - None, - [tma_desc_global_ptr_i64, tma_desc_shared_ptr_i32], - "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [$0], [$1], 128;", - "l,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def fence_tma_desc_release(*, loc=None, ip=None) -> None: - """ - See the `PTX documentation `__. - """ - llvm.inline_asm( - None, - [], - "fence.proxy.tensormap::generic.release.gpu;", - "", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py deleted file mode 100644 index 9b4aa0dbb207dfad2832ddf7a80504c7cf591ff1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +++ /dev/null @@ -1,249 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Optional, Tuple, Type, Union - -from cutlass.cutlass_dsl import dsl_user_op - -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir - -from .. import core -from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta -from ...impl_utils import check_type_in -from .cpasync.copy import ( - CopyBulkTensorTileG2SOp, - CopyBulkTensorTileG2SNonExecTrait, - CopyBulkTensorTileG2SMulticastOp, - CopyBulkTensorTileG2SMulticastNonExecTrait, -) - - -#################################################################################################### -# -# TMA creation helpers for tcgen05 MMAs -# -#################################################################################################### - - -@dsl_user_op -def make_tiled_tma_atom_A( - op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], - gmem_tensor: Tensor, - smem_layout: Union[Layout, core.ComposedLayout], - mma_tiler_mnk: Shape, - tiled_mma: core.TiledMma, - cluster_shape_vmnk: Shape, - *, - internal_type: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Tuple[core.CopyAtom, Tensor]: - """ - Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation - accounting for the MK projections of the TiledMMA for A tensor loads. - - Given - - - a GMEM tensor - - a SMEM layout - - a MMA Tiler - - a TiledMma - - a Cluster-level shape - - this function figures out the bulk tensor asynchronous copy instruction to use with the maximum - "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided - layout and consistent with the provided Tiler & tiled_mma (considering the M-mode & K-mode). - The Cluster-level shape is used to determine the multicast factor across the N-mode for A tensor loads. - - This function returns two results: - - 1. the Copy Atom - 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates - that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the - associated layout can output coordinates. Otherwise, TMA tensors can be partitioned - similarly to any other CuTe tensors using the algebra. - - :param op: The Copy Operation to construct an Atom for - :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] - :param gmem_tensor: The GMEM tensor to be loaded by this copy atom - :type gmem_tensor: Tensor - :param smem_layout: Shared memory layout to load the tensor into (PDSL) - :type smem_layout: Union[Layout, core.ComposedLayout] - :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions - :type mma_tiler_mnk: Shape - :param tiled_mma: The TiledMMA that will consume the load as operands - :type tiled_mma: core.TiledMma - :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions - :type cluster_shape_vmnk: Shape - :param internal_type: An optional parameter for the internal data type to when element - type does not match the copy type - :type internal_type: Type[Numeric] - :return: A copy atom for this operation and the associated TMA coord tensor - :rtype: Tuple[core.CopyAtom, Tensor] - - """ - - if internal_type is not None: - if not isinstance(internal_type, NumericMeta): - raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") - internal_type = internal_type.mlir_type - check_type_in( - op, - [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], - "op", - "make_tiled_tma_atom_A", - ) - - ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) - mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:]) - g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip) - cta_v_map = tiled_mma._thrfrg_A(g_tile) - cta_v_map = core.get(cta_v_map, mode=[1]) - cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) - - if isinstance(op, CopyBulkTensorTileG2SOp): - num_multicast = 1 - else: - assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) - # multicast across the N-mode since those would share the same tile of A - num_multicast = core.size(cluster_shape_vmnk, mode=[2]) - - # res[0] = the IR Value for the non-executable atom instance - # res[1] = the IR Value for the associated TMA tensor - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, - smem_layout, - cta_v_map, - op._to_ir(), - num_multicast=num_multicast, - internal_type=internal_type, - loc=loc, - ip=ip, - ) - if isinstance(op, CopyBulkTensorTileG2SOp): - return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] - else: - assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) - return ( - core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), - res[1], - ) - - -@dsl_user_op -def make_tiled_tma_atom_B( - op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], - gmem_tensor: Tensor, - smem_layout: Union[Layout, core.ComposedLayout], - mma_tiler_mnk: Shape, - tiled_mma: core.TiledMma, - cluster_shape_vmnk: Shape, - *, - internal_type: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Tuple[core.CopyAtom, Tensor]: - """ - Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation - accounting for the NK projections of the TiledMMA for B tensor loads. - - Given - - - a GMEM tensor - - a SMEM layout - - a MMA Tiler - - a TiledMma - - a Cluster-level shape - - this function figures out the bulk tensor asynchronous copy instruction to use with the maximum - "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided - layout and consistent with the provided Tiler & tiled_mma (considering the N-mode & K-mode). - The Cluster-level shape is used to determine the multicast factor across the M-mode for B tensor loads. - - This function returns two results: - - 1. the Copy Atom - 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates - that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the - associated layout can output coordinates. Otherwise, TMA tensors can be partitioned - similarly to any other CuTe tensors using the algebra. - - :param op: The Copy Operation to construct an Atom for - :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] - :param gmem_tensor: The GMEM tensor to be loaded by this copy atom - :type gmem_tensor: Tensor - :param smem_layout: Shared memory layout to load the tensor into (PDSL) - :type smem_layout: Union[Layout, core.ComposedLayout] - :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions - :type mma_tiler_mnk: Shape - :param tiled_mma: The TiledMMA that will consume the load as operands - :type tiled_mma: core.TiledMma - :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions - :type cluster_shape_vmnk: Shape - :param internal_type: An optional parameter for the internal data type to when element - type does not match the copy type - :type internal_type: Type[Numeric] - :return: A Copy Atom for this Operation and the associated TMA tensor - :rtype: Tuple[core.CopyAtom, Tensor] - - """ - - if internal_type is not None: - if not isinstance(internal_type, NumericMeta): - raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") - internal_type = internal_type.mlir_type - check_type_in( - op, - [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], - "op", - "make_tiled_tma_atom_B", - ) - - ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) - mma_tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:]) - g_tile = core.composition(ident, mma_tiler_nk, loc=loc, ip=ip) - cta_v_map = tiled_mma._thrfrg_B(g_tile) - cta_v_map = core.get(cta_v_map, mode=[1]) - cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) - - if isinstance(op, CopyBulkTensorTileG2SOp): - num_multicast = 1 - else: - assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) - # multicast across the M-mode since those would share the same tile of B - num_multicast = core.size(cluster_shape_vmnk, mode=[1]) - - # res[0] = the IR Value for the non-executable atom instance - # res[1] = the IR Value for the associated TMA tensor - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, - smem_layout, - cta_v_map, - op._to_ir(), - num_multicast=num_multicast, - internal_type=internal_type, - loc=loc, - ip=ip, - ) - if isinstance(op, CopyBulkTensorTileG2SOp): - return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] - else: - assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) - return ( - core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), - res[1], - ) - - -__all__ = [ - "make_tiled_tma_atom_A", - "make_tiled_tma_atom_B", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py deleted file mode 100644 index 2831bec6039b86a2231a5f05bdd3d1b9e0d891b0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .copy import * -from .mma import * -from .helpers import * - -# __all__ is required here for documentation generation -__all__ = [ - # - # copy.py - # - "Repetition", - "Pack", - "Unpack", - "Ld16x64bOp", - "Ld16x128bOp", - "Ld16x256bOp", - "Ld16x32bx2Op", - "Ld32x32bOp", - "St16x64bOp", - "St16x128bOp", - "St16x256bOp", - "St16x32bx2Op", - "St32x32bOp", - # - # mma.py - # - "OperandMajorMode", - "OperandSource", - "CtaGroup", - "Field", - "MmaTF32Op", - "MmaF16BF16Op", - "MmaI8Op", - "MmaFP8Op", - "MmaMXF8Op", - "MmaMXF4Op", - "MmaMXF4NVF4Op", - "SmemLayoutAtomKind", - # - # helpers.py - # - "make_smem_layout_atom", - "tile_to_mma_shape", - "commit", - "is_tmem_load", - "is_tmem_store", - "get_tmem_copy_properties", - "find_tmem_tensor_col_offset", - "make_tmem_copy", - "make_s2t_copy", - "get_s2t_smem_desc_tensor", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py deleted file mode 100644 index df954b09d5bcd30321df0dd65a9955fd30a0e811..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ /dev/null @@ -1,663 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from dataclasses import dataclass -from typing import Type - -from cutlass.cutlass_dsl import CuTeDSL - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from ..common import OpError -from ...core import CopyOp, Trait -from ...typing import Numeric - -from .mma import CtaGroup - - -class Repetition(enum.Enum): - """ - An enumeration for the number of repetitions of a given TMEM copy within the instruction. - """ - - x1 = 1 - x2 = 2 - x4 = 4 - x8 = 8 - x16 = 16 - x32 = 32 - x64 = 64 - x128 = 128 - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - @classmethod - def _missing_(cls, value): - if isinstance(value, int): - if value == 1: - return Repetition.x1 - elif value == 2: - return Repetition.x2 - elif value == 8: - return Repetition.x8 - elif value == 16: - return Repetition.x16 - elif value == 32: - return Repetition.x32 - elif value == 64: - return Repetition.x64 - elif value == 128: - return Repetition.x128 - - -class Pack(enum.Enum): - """ - An enumeration for the possible packing patterns for TMEM to RMEM copies. - """ - - NONE = enum.auto() - PACK_16b_IN_32b = enum.auto() - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - -class Unpack(enum.Enum): - """ - An enumeration for the possible unpacking patterns for RMEM to TMEM copies. - """ - - NONE = enum.auto() - UNPACK_32b_IN_16b = enum.auto() - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - -@dataclass(frozen=True) -class _LdBase(CopyOp): - repeat: Repetition = Repetition.x1 - pack: Pack = Pack.NONE - - admissible_archs = [ - "sm_100a", - "sm_100f", - ] - - def __post_init__(self) -> None: - # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - - if not isinstance(self.repeat, Repetition): - raise OpError( - self, - "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance", - ) - if not isinstance(self.pack, Pack): - raise OpError( - self, - "expects the 'pack' Op parameter to be a tcgen05.Pack instance", - ) - - def __str__(self) -> str: - res = ( - f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" - + f"\n number of repetitions = {self.repeat.value}" - ) - if self.pack == Pack.PACK_16b_IN_32b: - res += f"\n with 2x 16-bit to 32b packing" - return res - - -@dataclass(frozen=True) -class Ld16x64bOp(_LdBase): - """ - 16x64b TMEM load Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x64b`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Ld16x64bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( - copy_internal_type.mlir_type, - 16, - 64, - self.repeat.value, - ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, - ) - return Ld16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Ld16x64bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Ld16x128bOp(_LdBase): - """ - 16x128b TMEM load Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x128b`` qualifier. - """ - - def __post_init__(self) -> None: - super().__post_init__() - if self.repeat == Repetition.x128: - raise OpError( - self, - "x128 repetition is not supported", - suggestion="choose one of x1, x2, x4, x8, x16, x32, x64", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Ld16x128bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( - copy_internal_type.mlir_type, - 16, - 128, - self.repeat.value, - ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, - ) - return Ld16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Ld16x128bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Ld16x256bOp(_LdBase): - """ - 16x256b TMEM load Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x256b`` qualifier. - """ - - def __post_init__(self) -> None: - super().__post_init__() - if self.repeat in (Repetition.x128, Repetition.x64): - raise OpError( - self, - "x64 and x128 repetition is not supported", - suggestion="choose one of x1, x2, x4, x8, x16, x32", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Ld16x256bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( - copy_internal_type.mlir_type, - 16, - 256, - self.repeat.value, - ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, - ) - return Ld16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Ld16x256bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Ld16x32bx2Op(_LdBase): - """ - 16x32bx2 TMEM load Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x32bx2`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Ld16x32bx2Trait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( - copy_internal_type.mlir_type, - 16, - 32, - self.repeat.value, - ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, - ) - return Ld16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Ld16x32bx2Trait(Trait): - pass - - -@dataclass(frozen=True) -class Ld32x32bOp(_LdBase): - """ - 32x32b TMEM load Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.32x32`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Ld32x32bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( - copy_internal_type.mlir_type, - 32, - 32, - self.repeat.value, - ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, - ) - return Ld32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Ld32x32bTrait(Trait): - pass - - -@dataclass(frozen=True) -class _StBase(CopyOp): - repeat: Repetition - unpack: Unpack = Unpack.NONE - - admissible_archs = [ - "sm_100a", - "sm_100f", - ] - - def __post_init__(self) -> None: - # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - - if not isinstance(self.repeat, Repetition): - raise OpError( - self, - "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance", - ) - if not isinstance(self.unpack, Unpack): - raise OpError( - self, - "expects the 'pack' Op parameter to be a tcgen05.Unpack instance", - ) - - def __str__(self) -> str: - res = ( - f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" - + f"\n number of repetitions = {self.repeat.value}" - ) - if self.unpack == Unpack.UNPACK_32b_IN_16b: - res += f"\n with 32-bit to 2x 16b unpacking" - return res - - -@dataclass(frozen=True) -class St16x64bOp(_StBase): - """ - 16x64b TMEM store Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x64`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "St16x64bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( - copy_internal_type.mlir_type, - 16, - 64, - self.repeat.value, - ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, - ) - return St16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class St16x64bTrait(Trait): - pass - - -@dataclass(frozen=True) -class St16x128bOp(_StBase): - """ - 16x128b TMEM store Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x128`` qualifier. - """ - - def __post_init__(self) -> None: - super().__post_init__() - if self.repeat == Repetition.x128: - raise OpError( - self, - "x128 repetition is not supported", - suggestion="choose one of x1, x2, x4, x8, x16, x32, x64", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "St16x128bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( - copy_internal_type.mlir_type, - 16, - 128, - self.repeat.value, - ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, - ) - return St16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class St16x128bTrait(Trait): - pass - - -@dataclass(frozen=True) -class St16x256bOp(_StBase): - """ - 16x256b TMEM store Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x256`` qualifier. - """ - - def __post_init__(self) -> None: - super().__post_init__() - if self.repeat in (Repetition.x128, Repetition.x64): - raise OpError( - self, - "x64 and x128 repetition is not supported", - suggestion="choose one of x1, x2, x4, x8, x16, x32", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "St16x256bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( - copy_internal_type.mlir_type, - 16, - 256, - self.repeat.value, - ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, - ) - return St16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class St16x256bTrait(Trait): - pass - - -@dataclass(frozen=True) -class St16x32bx2Op(_StBase): - """ - 16x32x2b TMEM store Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.16x32x2`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "St16x32bx2Trait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( - copy_internal_type.mlir_type, - 16, - 32, - self.repeat.value, - ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, - ) - return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class St16x32bx2Trait(Trait): - pass - - -@dataclass(frozen=True) -class St32x32bOp(_StBase): - """ - 32x32b TMEM store Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.32x32`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "St32x32bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( - copy_internal_type.mlir_type, - 32, - 32, - self.repeat.value, - ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, - ) - return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class St32x32bTrait(Trait): - pass - - -@dataclass(frozen=True) -class _S2TCopyBase(CopyOp): - cta_group: CtaGroup - - admissible_archs = [ - "sm_100a", - "sm_100f", - ] - - def __post_init__(self) -> None: - # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - # Verify that the user provided enum values - if not isinstance(self.cta_group, CtaGroup): - raise OpError( - self, - "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", - ) - - def __str__(self) -> str: - res = ( - f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" - + f"\n CTA group = {self.cta_group}" - ) - - return res - - -@dataclass(frozen=True) -class Cp128x256bOp(_S2TCopyBase): - """ - 128x256b SMEM to TMEM Copy Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.128x256b`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Cp128x256bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( - copy_internal_type.mlir_type, - 128, - 256, - self.cta_group.value, - _cute_nvgpu_ir.CopyS2TBroadcast.none, - ) - return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Cp128x256bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Cp128x128bOp(_S2TCopyBase): - """ - 128x128b SMEM to TMEM Copy Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.128x128b`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Cp128x128bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( - copy_internal_type.mlir_type, - 128, - 128, - self.cta_group.value, - _cute_nvgpu_ir.CopyS2TBroadcast.none, - ) - return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Cp128x128bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Cp4x256bOp(_S2TCopyBase): - """ - 4x256b SMEM to TMEM Copy Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.4x256b`` qualifier. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Cp4x256bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( - copy_internal_type.mlir_type, - 4, - 256, - self.cta_group.value, - _cute_nvgpu_ir.CopyS2TBroadcast.none, - ) - return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Cp4x256bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Cp4x32x128bOp(_S2TCopyBase): - """ - 32x128b SMEM to TMEM Copy Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Cp4x32x128bTrait": - ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( - copy_internal_type.mlir_type, - 32, - 128, - self.cta_group.value, - _cute_nvgpu_ir.CopyS2TBroadcast.x4, - ) - return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Cp4x32x128bTrait(Trait): - pass - - -@dataclass(frozen=True) -class Cp2x64x128b0213Op(_S2TCopyBase): - """ - 64x128b SMEM to TMEM Copy Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Cp2x64x128b0213Trait": - ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( - copy_internal_type.mlir_type, - 64, - 128, - self.cta_group.value, - _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213, - ) - return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Cp2x64x128b0213Trait(Trait): - pass - - -@dataclass(frozen=True) -class Cp2x64x128b0123Op(_S2TCopyBase): - """ - 64x128b SMEM to TMEM Copy Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled. - """ - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "Cp2x64x128b0123Trait": - ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( - copy_internal_type.mlir_type, - 64, - 128, - self.cta_group.value, - _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123, - ) - return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class Cp2x64x128b0123Trait(Trait): - pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py deleted file mode 100644 index 0ad27e62962e874da6707ac8a36863d5ed8f98a4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +++ /dev/null @@ -1,328 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import overload, Type, Tuple, Union - -from cutlass.cutlass_dsl import dsl_user_op - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects import nvvm - -from ...typing import ( - Shape, - IntTuple, - Layout, - Tensor, - Int, - Numeric, - NumericMeta, - Int16, - Int32, -) -from ... import core -from .mma import SmemLayoutAtomKind, CtaGroup -from .copy import ( - Pack, - Unpack, - Ld16x64bOp, - Ld16x128bOp, - Ld16x256bOp, - Ld16x32bx2Op, - Ld32x32bOp, - St16x64bOp, - St16x128bOp, - St16x256bOp, - St16x32bx2Op, - St32x32bOp, -) - - -#################################################################################################### -# -# Helper functions for MMA -# -#################################################################################################### - - -@dsl_user_op -def make_smem_layout_atom( - kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None -) -> core.ComposedLayout: - """ - Makes a SMEM layout Atom. - - This function creates a composed layout in unit of elements consistent with the requested layout - Atom kind and element data type. - - :param kind: The kind of layout Atom - :type kind: SmemLayoutAtomKind - :param element_type: The element data type to construct the layout for - :type element_type: Type[Numeric] - :return: The SMEM layout atom - :rtype: core.ComposedLayout - """ - if not isinstance(element_type, NumericMeta): - raise TypeError(f"element_type must be a Numeric, but got {element_type}") - - if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER): - num_contiguous_bits = 128 - sw = core.make_swizzle(0, 4, 3) - elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32): - num_contiguous_bits = 256 - sw = core.make_swizzle(1, 4, 3) - elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64): - num_contiguous_bits = 512 - sw = core.make_swizzle(2, 4, 3) - elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128): - num_contiguous_bits = 1024 - sw = core.make_swizzle(3, 4, 3) - elif kind == SmemLayoutAtomKind.MN_SW128_32B: - num_contiguous_bits = 1024 - sw = core.make_swizzle(2, 5, 2) - else: - raise ValueError("unrecognized SMEM layout atom kind") - num_contiguous_elems = num_contiguous_bits // element_type.width - - if kind in ( - SmemLayoutAtomKind.MN_INTER, - SmemLayoutAtomKind.MN_SW32, - SmemLayoutAtomKind.MN_SW64, - SmemLayoutAtomKind.MN_SW128, - SmemLayoutAtomKind.MN_SW128_32B, - ): - # M/N-major layout - return core.make_composed_layout( - sw, - 0, - core.make_layout( - (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) - ), - loc=loc, - ip=ip, - ) - else: - # K-major layout - return core.make_composed_layout( - sw, - 0, - core.make_layout( - (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) - ), - loc=loc, - ip=ip, - ) - - -@overload -def tile_to_mma_shape( - atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None -) -> Layout: ... - - -@overload -def tile_to_mma_shape( - atom: core.ComposedLayout, - mma_tile_shape: Shape, - order: IntTuple = None, - *, - loc=None, - ip=None, -) -> core.ComposedLayout: ... - - -@dsl_user_op -def tile_to_mma_shape( - atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None -): - """ - Tiles a layout to an MMA shape. - """ - # Default order is colexicographical - if order is None: - order = tuple(range(core.rank(mma_tile_shape) - 1)) - if core.rank(order) != core.rank(mma_tile_shape) - 1: - raise ValueError( - f"rank(order)={core.rank(order)} must be equal to " - f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}" - ) - order_val = core._pack_int_tuple(order, loc=loc, ip=ip) - mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip) - - if not ( - core.is_static(atom) - and core.is_static(mma_tile_shape_val) - and core.is_static(order_val) - ): - raise ValueError("tile_to_mma_shape only supports static inputs") - - res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val) - return _cute_ir.static(res_ty, loc=loc, ip=ip) - - -@dsl_user_op -def commit( - mbar_ptr: core.Pointer, - mask=None, - cta_group: CtaGroup = CtaGroup.ONE, - *, - loc=None, - ip=None, -) -> None: - """ - Perform an arrive operation on a mbarrier upon completion of previous MMA operations. - - :param mbar_ptr: A pointer to the mbarrier in SMEM - :type mbar_ptr: Pointer - :param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to - :type mask: Int - """ - if cta_group == CtaGroup.ONE: - group = nvvm.Tcgen05GroupKind.CTA_1 - else: - assert cta_group == CtaGroup.TWO - group = nvvm.Tcgen05GroupKind.CTA_2 - - mbar_ptr = mbar_ptr.llvm_ptr - if mask is not None: - mask = Int16(mask).ir_value(loc=loc, ip=ip) - nvvm.tcgen05_commit_arrive( - mbar_ptr, multicast_mask=mask, group=group, loc=loc, ip=ip - ) - else: - nvvm.tcgen05_commit_arrive(mbar_ptr, group=group, loc=loc, ip=ip) - return - - -#################################################################################################### -# -# Helper functions for Copies -# -#################################################################################################### - - -def is_tmem_load(atom: core.CopyAtom) -> bool: - """ - Returns whether a CopyAtom instance is a TMEM load. - """ - return isinstance( - atom.op, - ( - Ld16x64bOp, - Ld16x128bOp, - Ld16x256bOp, - Ld16x32bx2Op, - Ld32x32bOp, - ), - ) - - -def is_tmem_store(atom: core.CopyAtom) -> bool: - """ - Returns whether a CopyAtom instance is a TMEM store. - """ - return isinstance( - atom.op, - ( - St16x64bOp, - St16x128bOp, - St16x256bOp, - St16x32bx2Op, - St32x32bOp, - ), - ) - - -def get_tmem_copy_properties( - atom: core.CopyAtom, -) -> Tuple[int, int, int, Union[Pack, Unpack]]: - """ - Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions, - and whether packing/unpacking is used). - """ - if isinstance(atom.op, (Ld16x64bOp, St16x64bOp)): - num_dp, num_bits = 16, 64 - elif isinstance(atom.op, (Ld16x128bOp, St16x128bOp)): - num_dp, num_bits = 16, 128 - elif isinstance(atom.op, (Ld16x256bOp, St16x256bOp)): - num_dp, num_bits = 16, 256 - elif isinstance(atom.op, (Ld16x32bx2Op, St16x32bx2Op)): - num_dp, num_bits = 16, 32 - elif isinstance(atom.op, (Ld32x32bOp, St32x32bOp)): - num_dp, num_bits = 32, 32 - else: - raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}") - if is_tmem_load(atom): - return num_dp, num_bits, atom.op.repeat.value, atom.op.pack - else: - assert is_tmem_store(atom), "atom must be a TMEM store" - return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack - - -@dsl_user_op -def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int: - """ - Computes the TMEM column offset given a TMEM tensor. - - :param tmem_tensor: The TMEM tensor to use to compute the columns offset - :type tmem_tensor: Tensor - :return: The columns offset - :rtype: Int - """ - tmem_col_mask = 0x0000FFFF - offset = ( - core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip) - & tmem_col_mask - ) - if isinstance(offset, int): - return offset - return Int32(offset, loc=loc, ip=ip) - - -@dsl_user_op -def make_tmem_copy( - atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None -) -> core.TiledCopy: - """ - Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. - """ - tiled_copy_val = _cute_nvgpu_ir.atom_make_tmem_copy( - atom._trait.value, tmem_tensor.value, loc=loc, ip=ip - ) - new_trait = type(atom._trait)(tiled_copy_val) - return core.TiledCopy(atom.op, new_trait) - - -@dsl_user_op -def make_s2t_copy( - atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None -) -> core.TiledCopy: - """ - Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. - """ - tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy( - atom._trait.value, tmem_tensor.value, loc=loc, ip=ip - ) - new_trait = type(atom._trait)(tiled_copy_val) - return core.TiledCopy(atom.op, new_trait) - - -@dsl_user_op -def get_s2t_smem_desc_tensor( - atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None -) -> Tensor: - """ - Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor. - """ - smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view( - atom._trait.value, smem_tensor.value, loc=loc, ip=ip - ) - return smem_desc_tensor diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py deleted file mode 100644 index 3a938523e130cf551c205669164e15e8bbd29132..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ /dev/null @@ -1,1041 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from dataclasses import dataclass -from typing import Type - -from cutlass.cutlass_dsl import CuTeDSL, T - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from ..common import OpError -from ... import core -from ...core import Trait, _pack_shape, rank, depth, _Tensor -from ...typing import ( - Shape, - Float4E2M1FN, - Float8E8M0FNU, - Float8E5M2, - Float8E4M3FN, - Float16, - BFloat16, - Float32, - TFloat32, - Boolean, - Int8, - Uint8, - Int32, - Numeric, - AddressSpace, - Pointer, -) - - -#################################################################################################### -# -# MMA Ops and Traits -# -#################################################################################################### - - -class OperandMajorMode(enum.Enum): - """ - An enumeration for the majorness of the input operands of the MMA. - """ - - MN = _cute_ir.MajorMode.mn - K = _cute_ir.MajorMode.k - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - @classmethod - def _missing_(cls, value): - if isinstance(value, str): - value = value.upper() - if value == "MN": - return OperandMajorMode.MN - elif value == "K": - return OperandMajorMode.K - - def _to_ir(self) -> _cute_ir.MajorMode: - return self.value - - -class OperandSource(enum.Enum): - """ - An enumeration for the source memory location of the A input operand of the MMA. - """ - - TMEM = _cute_ir.MmaFragKind.tmem - SMEM = _cute_ir.MmaFragKind.smem_desc - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir(self) -> _cute_ir.MmaFragKind: - return self.value - - -class CtaGroup(enum.Enum): - """ - An enumeration for the ``cta_group`` qualifier of the MMA. - """ - - ONE = 1 - TWO = 2 - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - -class Field(enum.Enum): - """ - An enumeration for the fields of the MMA Atom that can be modified at runtime. - """ - - NEGATE_A = "neg_a" - NEGATE_B = "neg_b" - ACCUMULATE = "accum_c" - SFA = "sf_a" - SFB = "sf_b" - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir_field_name(self) -> str: - return self.value - - -# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code -@dataclass(frozen=True) -class MmaOp(core.MmaOp): - a_dtype: Type[Numeric] - b_dtype: Type[Numeric] - acc_dtype: Type[Numeric] - shape_mnk: Shape - cta_group: CtaGroup - a_src: OperandSource - a_major_mode: OperandMajorMode - b_major_mode: OperandMajorMode - - admissible_archs = [ - "sm_100a", - "sm_100f", - ] - - def __post_init__(self) -> None: - # Verify arch - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - # Verify that the user provided enum values - if not isinstance(self.cta_group, CtaGroup): - raise OpError( - self, - "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", - ) - if not isinstance(self.a_src, OperandSource): - raise OpError( - self, - "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", - ) - if not isinstance(self.a_major_mode, OperandMajorMode): - raise OpError( - self, - "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", - ) - if not isinstance(self.b_major_mode, OperandMajorMode): - raise OpError( - self, - "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", - ) - # Verify the instruction shape - if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): - raise OpError( - self, - f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " - f"but got {self.shape_mnk}", - ) - m, n = self.shape_mnk[0], self.shape_mnk[1] - if self.cta_group == CtaGroup.ONE: - if m not in [64, 128]: - raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}") - if m == 64: - if (n < 8) or (n > 256) or (n % 8 != 0): - raise OpError( - self, - f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", - ) - elif m == 128: - if (n < 16) or (n > 256) or (n % 16 != 0): - raise OpError( - self, - f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}", - ) - else: - if m not in [128, 256]: - raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") - if (n < 32) or (n > 256) or (n % 32 != 0): - raise OpError( - self, - f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}", - ) - - def __str__(self) -> str: - return ( - self.__class__.descriptive_name # type: ignore - + f"\n A data type = {self.a_dtype}" - + f"\n B data type = {self.b_dtype}" - + f"\n Accumulator data type = {self.acc_dtype}" - + f"\n CTA group = {self.cta_group}" - + f"\n A source location = {self.a_src}" - + f"\n A major mode = {self.a_major_mode}" - + f"\n B major mode = {self.b_major_mode}" - + f"\n Instruction shape MNK = {self.shape_mnk}" - ) - - def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): - if input.memspace == AddressSpace.smem and isinstance( - input.layout.type, _cute_ir.ComposedLayoutType - ): - raise OpError( - self, - f"Expected affine layout for {self._make_trait()}'s operand A, " - f"but got composed layout instead: {input.layout}" - f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", - ) - return True - - def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): - if input.memspace == AddressSpace.smem and isinstance( - input.layout.type, _cute_ir.ComposedLayoutType - ): - raise OpError( - self, - f"Expected affine layout for {self._make_trait()}'s operand B, " - f"but got composed layout instead: {input.layout}" - f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", - ) - return True - - -class MmaTrait(Trait): - admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] - - def set(self, field, value, *, loc=None, ip=None) -> None: - if field not in self.admissible_fields: - raise ValueError( - f"expects field to be one of {self.admissible_fields}, but got {field}" - ) - field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>" - attr = ir.Attribute.parse(field_name) - self.value = _cute_nvgpu_ir.atom_set_value( - self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - - -# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code -@dataclass(frozen=True) -class BlockScaledMmaOp(core.MmaOp): - a_dtype: Type[Numeric] - b_dtype: Type[Numeric] - acc_dtype: Float32 - sf_dtype: Type[Numeric] - sf_vec_size: int - shape_mnk: Shape - cta_group: CtaGroup - a_src: OperandSource - a_major_mode: OperandMajorMode - b_major_mode: OperandMajorMode - - admissible_archs = [ - "sm_100a", - ] - - def __post_init__(self) -> None: - # Verify arch - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - # Verify that the user provided enum values - if not isinstance(self.cta_group, CtaGroup): - raise OpError( - self, - "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", - ) - if not isinstance(self.a_src, OperandSource): - raise OpError( - self, - "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", - ) - if not isinstance(self.a_major_mode, OperandMajorMode): - raise OpError( - self, - "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", - ) - if not isinstance(self.b_major_mode, OperandMajorMode): - raise OpError( - self, - "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", - ) - # Verify the instruction shape - if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): - raise OpError( - self, - f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " - f"but got {self.shape_mnk}", - ) - m, n = self.shape_mnk[0], self.shape_mnk[1] - if self.cta_group == CtaGroup.ONE: - if m != 128: - raise OpError(self, f"expects the M-mode to be 128, but got {m}") - - if (n < 8) or (n > 256) or (n % 8 != 0): - raise OpError( - self, - f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", - ) - else: - if m not in [128, 256]: - raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") - if (n < 16) or (n > 256) or (n % 16 != 0): - raise OpError( - self, - f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}", - ) - if self.sf_vec_size not in [16, 32]: - raise OpError( - self, - f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}", - ) - - def __str__(self) -> str: - return ( - self.__class__.descriptive_name # type: ignore - + f"\n A data type = {self.a_dtype}" - + f"\n B data type = {self.b_dtype}" - + f"\n Accumulator data type = {self.acc_dtype}" - + f"\n Scale factor data type = {self.sf_dtype}" - + f"\n Scale factor vector size = {self.sf_vec_size}" - + f"\n CTA group = {self.cta_group}" - + f"\n A source location = {self.a_src}" - + f"\n A major mode = {self.a_major_mode}" - + f"\n B major mode = {self.b_major_mode}" - + f"\n Instruction shape MNK = {self.shape_mnk}" - ) - - def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): - if input.memspace == AddressSpace.smem and isinstance( - input.layout.type, _cute_ir.ComposedLayoutType - ): - raise OpError( - self, - f"Expected affine layout for {self._make_trait()}'s operand A, " - f"but got composed layout instead: {input.layout}" - f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", - ) - return True - - def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): - if input.memspace == AddressSpace.smem and isinstance( - input.layout.type, _cute_ir.ComposedLayoutType - ): - raise OpError( - self, - f"Expected affine layout for {self._make_trait()}'s operand B, " - f"but got composed layout instead: {input.layout}" - f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", - ) - return True - - -class BlockScaledMmaTraits(Trait): - admissible_fields = [ - Field.ACCUMULATE, - Field.NEGATE_A, - Field.NEGATE_B, - Field.SFA, - Field.SFB, - ] - - def set(self, field, value, *, loc=None, ip=None) -> None: - if field not in self.admissible_fields: - raise ValueError( - f"expects field to be one of {self.admissible_fields}, but got {field}" - ) - if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]: - value = Boolean(value).ir_value(loc=loc, ip=ip) - elif field in [Field.SFA, Field.SFB]: - if not isinstance(value, Pointer): - raise ValueError( - f"expects value to be a pointer for {field}, but got {type(value).__name__}" - ) - value = value.value - - field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>" - attr = ir.Attribute.parse(field_name) - self.value = _cute_nvgpu_ir.atom_set_value( - self.value, attr, value, loc=loc, ip=ip - ) - - -# -# TF32 MMA -# - - -@dataclass(frozen=True) -class MmaTF32Op(MmaOp): - """ - TF32 tcgen05 MMA Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.kind::tf32`` qualifier. - """ - - descriptive_name = "tcgen05 TF32 MMA Operation" - - def __init__( - self, - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - super().__init__( - TFloat32, - TFloat32, - Float32, - instruction_shape, - cta_group, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self) -> None: - # Verify the instruction shape - instruction_k = 8 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.a_src._to_ir(), - 0, - ) - return MmaTF32Trait( - _cute_nvgpu_ir.make_sm100_mma( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -class MmaTF32Trait(MmaTrait): - pass - - -# -# F16/BF16 MMA -# - - -@dataclass(frozen=True) -class MmaF16BF16Op(MmaOp): - """ - F16/BF16 tcgen05 MMA Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.kind::f16`` qualifier. - """ - - descriptive_name = "tcgen05 F16/BF16 MMA Operation" - - def __init__( - self, - ab_dtype: Type[Numeric], - acc_dtype: Type[Numeric], - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - super().__init__( - ab_dtype, - ab_dtype, - acc_dtype, - instruction_shape, - cta_group, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self) -> None: - # Input data type verification - if self.a_dtype not in [Float16, BFloat16]: - raise OpError( - self, - "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", - ) - assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" - # Accumulator data type verification - if self.acc_dtype not in [Float16, Float32]: - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", - ) - # Instruction shape verification - instruction_k = 16 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.a_src._to_ir(), - 0, - ) - return MmaF16BF16Trait( - _cute_nvgpu_ir.make_sm100_mma( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -class MmaF16BF16Trait(MmaTrait): - pass - - -# -# I8 MMA -# - - -@dataclass(frozen=True) -class MmaI8Op(MmaOp): - """ - I8 tcgen05 MMA Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.kind::i8`` qualifier. - """ - - descriptive_name = "tcgen05 I8 MMA Operation" - - def __init__( - self, - ab_dtype: Type[Numeric], - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - super().__init__( - ab_dtype, - ab_dtype, - Int32, - instruction_shape, - cta_group, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self) -> None: - # Input data type verification - if self.a_dtype not in [Int8, Uint8]: - raise OpError( - self, - "expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8", - ) - assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" - # Instruction shape verification - instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - (T.si8() if self.a_dtype.signed else T.ui8()), - (T.si8() if self.b_dtype.signed else T.ui8()), - T.si32(), - self.a_src._to_ir(), - 0, - ) - return MmaI8Trait( - _cute_nvgpu_ir.make_sm100_mma( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -class MmaI8Trait(MmaTrait): - pass - - -# -# F8F6F4 MMA -# - - -@dataclass(frozen=True) -class MmaFP8Op(MmaOp): - """ - F8 tcgen05 MMA Operation. - - See the `PTX documentation `__. - """ - - descriptive_name = "tcgen05 F8 MMA Operation" - - def __init__( - self, - ab_dtype: Type[Numeric], - acc_dtype: Type[Numeric], - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - - super().__init__( - ab_dtype, - ab_dtype, - acc_dtype, - instruction_shape, - cta_group, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self) -> None: - # Input data type verification - if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: - raise OpError( - self, - "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", - ) - assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" - # Accumulator data type verification - if self.acc_dtype not in [Float16, Float32]: - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", - ) - # Instruction shape verification - instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.a_src._to_ir(), - 0, - ) - return MmaFP8Trait( - _cute_nvgpu_ir.make_sm100_mma( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -class MmaFP8Trait(MmaTrait): - pass - - -# -# MXF8F6F4 MMA -# - - -@dataclass(frozen=True) -class MmaMXF8Op(BlockScaledMmaOp): - """ - MXF8 tcgen05 BlockScaled MMA Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier. - """ - - descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation" - - def __init__( - self, - ab_dtype: Type[Numeric], - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - super().__init__( - ab_dtype, - ab_dtype, - Float32, - Float8E8M0FNU, - 32, - instruction_shape, - cta_group, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self) -> None: - # Input data type verification - if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: - raise OpError( - self, - "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", - ) - assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" - # Instruction shape verification - instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.sf_dtype.mlir_type, - self.a_src._to_ir(), - self.sf_vec_size, - ) - return MmaMXF8Trait( - _cute_nvgpu_ir.make_sm100_mma_bs( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - loc=loc, - ip=ip, - ) - ) - - -class MmaMXF8Trait(BlockScaledMmaTraits): - pass - - -# -# MXF4 MMA -# - - -@dataclass(frozen=True) -class MmaMXF4Op(BlockScaledMmaOp): - """ - MXF4 tcgen05 BlockScaled MMA Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.kind::mxf4`` qualifier. - """ - - descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation" - - def __init__( - self, - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - ) -> None: - super().__init__( - Float4E2M1FN, - Float4E2M1FN, - Float32, - Float8E8M0FNU, - 32, - instruction_shape, - cta_group, - a_src, - OperandMajorMode.K, - OperandMajorMode.K, - ) - self._verify() - - def _verify(self) -> None: - # Instruction shape verification - instruction_k = 64 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.sf_dtype.mlir_type, - self.a_src._to_ir(), - self.sf_vec_size, - ) - return MmaMXF4Trait( - _cute_nvgpu_ir.make_sm100_mma_bs( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - loc=loc, - ip=ip, - ) - ) - - -class MmaMXF4Trait(BlockScaledMmaTraits): - pass - - -# -# MXF4NVF4 MMA -# - - -@dataclass(frozen=True) -class MmaMXF4NVF4Op(BlockScaledMmaOp): - """ - MXF4NVF4 tcgen05 BlockScaled MMA Operation. - - See the `PTX documentation `__. - This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier. - """ - - descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation" - - def __init__( - self, - sf_dtype: Type[Numeric], - instruction_shape: Shape, - cta_group: CtaGroup, - a_src: OperandSource, - ) -> None: - super().__init__( - Float4E2M1FN, - Float4E2M1FN, - Float32, - sf_dtype, - 16, - instruction_shape, - cta_group, - a_src, - OperandMajorMode.K, - OperandMajorMode.K, - ) - self._verify() - - def _verify(self) -> None: - # Scale Factor data type verification - if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]: - raise OpError( - self, - "expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU", - ) - # Instruction shape verification - instruction_k = 64 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( - shape_mnk.type.attribute, - self.cta_group.value, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.sf_dtype.mlir_type, - self.a_src._to_ir(), - self.sf_vec_size, - ) - return MmaMXF4NVF4Trait( - _cute_nvgpu_ir.make_sm100_mma_bs( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - loc=loc, - ip=ip, - ) - ) - - -class MmaMXF4NVF4Trait(BlockScaledMmaTraits): - pass - -#################################################################################################### -# -# SMEM layout atoms -# -#################################################################################################### - - -class SmemLayoutAtomKind(enum.Enum): - """ - Enum class for the kinds of SMEM layout atoms for SM100. - - Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be - used to construct an SMEM layout using blocked product for operand A or B such that the - resulting layout is legal for both TMA and UMMA. - - Note that there are other ways of creating legal layouts for operand A and B. - """ - - MN_INTER = enum.auto() - MN_SW32 = enum.auto() - MN_SW64 = enum.auto() - MN_SW128 = enum.auto() - MN_SW128_32B = enum.auto() - K_INTER = enum.auto() - K_SW32 = enum.auto() - K_SW64 = enum.auto() - K_SW128 = enum.auto() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py deleted file mode 100644 index c2b3f7cf5b0698752d7ea6c450782f17a3fee797..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .copy import * -from .mma import * - - -# __all__ is required here for documentation generation -__all__ = [ - # mma.py - "MmaF16BF16Op", - # copy.py - "LdMatrix8x8x16bOp", - "LdMatrix16x16x8bOp", - "StMatrix8x8x16bOp", - "StMatrix16x8x8bOp", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py deleted file mode 100644 index a6ad4ca8f0e2dd05b6e779eaedec0b69cd47decf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from dataclasses import dataclass -from typing import Type - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from ..common import OpError -from ...core import CopyOp, Trait, _pack_shape -from ...typing import Numeric - - -@dataclass(frozen=True) -class BaseOp(CopyOp): - transpose: bool = False - num_matrices: int = 1 - - def __post_init__(self) -> None: - if not isinstance(self.transpose, bool): - raise OpError( - self, - "expects the 'transpose' Op parameter to be a bool instance", - ) - - def __str__(self) -> str: - res = ( - f"{self.__class__.__name__[:-2]} Copy Operation" - + f"\n number of matrices = {self.num_matrices}" - ) - if self.transpose: - res += f"\n transposed" - return res - - -@dataclass(frozen=True) -class LdMatrix8x8x16bOp(BaseOp): - """ - 8x8 ``ldmatrix`` Operation. - - See the `PTX documentation `__. - This operation corresponds to the ``.m8n8`` qualifier. - """ - - def __post_init__(self) -> None: - super().__post_init__() - if self.num_matrices not in [1, 2, 4]: - raise OpError( - self, - "expects the 'num_matrices' Op parameter to be one of [1,2,4]", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "LdMatrix8x8x16bTrait": - mode = _pack_shape((8, 8), loc=loc, ip=ip) - ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( - copy_internal_type.mlir_type, - mode.type.attribute, - _cute_nvgpu_ir.LdsmSzPattern.u16, - self.num_matrices, - ir.UnitAttr.get() if self.transpose else None, - ) - return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class LdMatrix8x8x16bTrait(Trait): - pass - - -@dataclass(frozen=True) -class LdMatrix16x16x8bOp(BaseOp): - """ - 16x16 8-bit ``ldmatrix`` Operation. - - See the `PTX documentation `__. - This operation corresponds to the ``.m16n16`` and the ``.b16`` qualifiers. - """ - - def __init__(self, num_matrices: int) -> None: - super().__init__(transpose=True, num_matrices=num_matrices) - self._verify() - - def _verify(self): - assert self.transpose, "transpose must be True" - if self.num_matrices not in [1, 2]: - raise OpError( - self, - "expects the 'num_matrices' Op parameter to be one of [1,2]", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "LdMatrix16x16x8bTrait": - mode = _pack_shape((16, 16), loc=loc, ip=ip) - ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( - copy_internal_type.mlir_type, - mode.type.attribute, - _cute_nvgpu_ir.LdsmSzPattern.u8, - self.num_matrices, - ir.UnitAttr.get(), - ) - return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class LdMatrix16x16x8bTrait(Trait): - pass - - -@dataclass(frozen=True) -class StMatrix8x8x16bOp(BaseOp): - """ - 8x8 ``stmatrix`` Operation. - - See the `PTX documentation `__. - This operation corresponds to the ``m8n8`` qualifier. - """ - - def __post_init__(self) -> None: - super().__post_init__() - if self.num_matrices not in [1, 2, 4]: - raise OpError( - self, - "expects the 'num_matrices' Op parameter to be one of [1,2,4]", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "StMatrix8x8x16bTrait": - mode = _pack_shape((8, 8), loc=loc, ip=ip) - ty = _cute_nvgpu_ir.CopyAtomStsmType.get( - copy_internal_type.mlir_type, - mode.type.attribute, - self.num_matrices, - ir.UnitAttr.get() if self.transpose else None, - ) - return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class StMatrix8x8x16bTrait(Trait): - pass - - -@dataclass(frozen=True) -class StMatrix16x8x8bOp(BaseOp): - """ - 16x8 ``stmatrix`` Operation. - - See the `PTX documentation `__. - This operation corresponds to the ``m16n8`` qualifier. - """ - - def __init__(self, num_matrices: int) -> None: - super().__init__(transpose=True, num_matrices=num_matrices) - self._verify() - - def _verify(self): - if self.num_matrices not in [1, 2, 4]: - assert self.transpose, "transpose must be True" - raise OpError( - self, - "expects the 'num_matrices' Op parameter to be one of [1,2,4]", - ) - - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ) -> "StMatrix16x8x8bTrait": - mode = _pack_shape((16, 8), loc=loc, ip=ip) - ty = _cute_nvgpu_ir.CopyAtomStsmType.get( - copy_internal_type.mlir_type, - mode.type.attribute, - self.num_matrices, - ir.UnitAttr.get(), - ) - return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - -class StMatrix16x8x8bTrait(Trait): - pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py deleted file mode 100644 index 49df213b76f24f23ecfe5a75e36cf17d35aeb98b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from dataclasses import dataclass -from typing import Type - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir - -from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, _Tensor -from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace - - -@dataclass(frozen=True) -class MmaF16BF16Op(MmaOp): - """ - F16/BF16 tcgen05 MMA Operation. - - See the `PTX documentation `__. - This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. - """ - - ab_dtype: Type[Numeric] - acc_dtype: Type[Numeric] - shape_mnk: Shape - - def __post_init__(self) -> None: - if self.ab_dtype not in [Float16, BFloat16]: - raise OpError( - self, - "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", - ) - if self.acc_dtype not in [Float16, Float32]: - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", - ) - if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32): - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", - ) - if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]: - raise OpError( - self, - "expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM80Type.get( - shape_mnk.type.attribute, - self.ab_dtype.mlir_type, - self.ab_dtype.mlir_type, - self.acc_dtype.mlir_type, - ) - return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) - - def __str__(self) -> str: - return ( - "warp-level F16/BF16 MMA Operation" - + f"\n A/B data type = {self.ab_dtype}" - + f"\n Accumulator data type = {self.acc_dtype}" - + f"\n Instruction shape MNK = {self.shape_mnk}" - ) - - def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): - pass - - def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): - pass - -class MmaF16BF16Trait(Trait): - pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py deleted file mode 100644 index 49a40165033024c9c9b17acd298a1f8ba055649c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .mma import * -from .helpers import * - -# __all__ is required here for documentation generation -__all__ = [ - # mma.py - "OperandMajorMode", - "OperandSource", - "Field", - "MmaF16BF16Op", - "MmaF8Op", - "SmemLayoutAtomKind", - # helpers.py - "make_smem_layout_atom", - "fence", - "commit_group", - "wait_group", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py deleted file mode 100644 index f6284134933bec170ecec5eeb0bf9f829ef0dff0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Type - -from cutlass.cutlass_dsl import dsl_user_op - -from cutlass._mlir.dialects import nvvm - -from ...typing import Numeric, NumericMeta -from ... import core -from .mma import SmemLayoutAtomKind - - -@dsl_user_op -def make_smem_layout_atom( - kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None -) -> core.ComposedLayout: - """ - Makes a SMEM layout Atom. - - This function creates a composed layout in unit of elements consistent with the requested layout - Atom kind and element data type. - - :param kind: The kind of layout Atom - :type kind: SmemLayoutAtomKind - :param element_type: The element data type to construct the layout for - :type element_type: Type[Numeric] - :return: The SMEM layout atom - :rtype: core.ComposedLayout - """ - if not isinstance(element_type, NumericMeta): - raise TypeError(f"element_type must be a Numeric, but got {element_type}") - - if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER): - num_contiguous_bits = 128 - sw = core.make_swizzle(0, 4, 3) - elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32): - num_contiguous_bits = 256 - sw = core.make_swizzle(1, 4, 3) - elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64): - num_contiguous_bits = 512 - sw = core.make_swizzle(2, 4, 3) - elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128): - num_contiguous_bits = 1024 - sw = core.make_swizzle(3, 4, 3) - else: - raise ValueError("unrecognized SMEM layout atom kind") - num_contiguous_elems = num_contiguous_bits // element_type.width - - if kind in ( - SmemLayoutAtomKind.MN_INTER, - SmemLayoutAtomKind.MN_SW32, - SmemLayoutAtomKind.MN_SW64, - SmemLayoutAtomKind.MN_SW128, - ): - # M/N-major layout - return core.make_composed_layout( - sw, - 0, - core.make_layout( - (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) - ), - loc=loc, - ip=ip, - ) - else: - # K-major layout - return core.make_composed_layout( - sw, - 0, - core.make_layout( - (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) - ), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def fence(*, loc=None, ip=None) -> None: - """ - See the `PTX documentation `__. - """ - nvvm.wgmma_fence_aligned(loc=None, ip=None) - - -@dsl_user_op -def commit_group(*, loc=None, ip=None) -> None: - """ - See the `PTX documentation `__. - """ - nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip) - - -@dsl_user_op -def wait_group(group, *, loc=None, ip=None) -> None: - """ - See the `PTX documentation `__. - """ - nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py deleted file mode 100644 index 275861f70cc3d6eca932cb263890aaaa4121445f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +++ /dev/null @@ -1,405 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from dataclasses import dataclass -from typing import Type - -from cutlass.cutlass_dsl import CuTeDSL - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir - -from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor -from ...typing import ( - Shape, - Float16, - BFloat16, - Float32, - Boolean, - Float8E5M2, - Float8E4M3FN, - Numeric, - AddressSpace, -) - - -#################################################################################################### -# -# MMA Ops and Traits -# -#################################################################################################### - - -class OperandMajorMode(enum.Enum): - """ - An enumeration for the majorness of the input operands of the MMA. - """ - - MN = _cute_ir.MajorMode.mn - K = _cute_ir.MajorMode.k - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - @classmethod - def _missing_(cls, value): - if isinstance(value, str): - value = value.upper() - if value == "MN": - return OperandMajorMode.MN - elif value == "K": - return OperandMajorMode.K - - def _to_ir(self) -> _cute_ir.MajorMode: - return self.value - - -class OperandSource(enum.Enum): - """ - An enumeration for the source memory location of the A input operand of the MMA. - """ - - RMEM = _cute_ir.MmaFragKind.rmem - SMEM = _cute_ir.MmaFragKind.smem_desc - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir(self) -> _cute_ir.MmaFragKind: - return self.value - - -class Field(enum.Enum): - """ - An enumeration for the fields of the MMA Atom that can be modified at runtime. - """ - - ACCUMULATE = "accum_c" - - def __str__(self) -> str: - return f"{self.__class__.__name__}.{self.name}" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}.{self.name}>" - - def _to_ir_field_name(self) -> str: - return self.value - - -@dataclass(frozen=True) -class MmaOp(MmaOp): - a_dtype: Type[Numeric] - b_dtype: Type[Numeric] - acc_dtype: Type[Numeric] - shape_mnk: Shape - a_src: OperandSource - a_major_mode: OperandMajorMode - b_major_mode: OperandMajorMode - - admissible_archs = ["sm_90a"] - - def __post_init__(self) -> None: - # Verify arch - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: - raise OpError( - self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", - suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", - ) - # Verify that the user provided enum values - if not isinstance(self.a_src, OperandSource): - raise OpError( - self, - "expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance", - ) - if not isinstance(self.a_major_mode, OperandMajorMode): - raise OpError( - self, - "expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", - ) - if not isinstance(self.b_major_mode, OperandMajorMode): - raise OpError( - self, - "expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", - ) - # Verify instruction shape - if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): - raise OpError( - self, - f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " - f"but got {self.shape_mnk}", - ) - m, n = self.shape_mnk[0], self.shape_mnk[1] - if m != 64: - raise OpError(self, f"expects the M-mode to be 64, but got {m}") - if (n < 8) or (n > 256) or (n % 8 != 0): - raise OpError( - self, - f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. but got {n}", - ) - - def __str__(self) -> str: - return ( - self.__class__.descriptive_name # type: ignore - + f"\n A data type = {self.a_dtype}" - + f"\n B data type = {self.b_dtype}" - + f"\n Accumulator data type = {self.acc_dtype}" - + f"\n A source location = {self.a_src}" - + f"\n A major mode = {self.a_major_mode}" - + f"\n B major mode = {self.b_major_mode}" - + f"\n Instruction shape MNK = {self.shape_mnk}" - ) - - def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): - if input.memspace == AddressSpace.smem and isinstance( - input.layout.type, _cute_ir.ComposedLayoutType - ): - raise OpError( - self, - f"Expected affine layout for {self._make_trait()}'s operand A, " - f"but got composed layout instead: {input.layout}" - f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", - ) - return True - - def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): - if input.memspace == AddressSpace.smem and isinstance( - input.layout.type, _cute_ir.ComposedLayoutType - ): - raise OpError( - self, - f"Expected affine layout for {self._make_trait()}'s operand B, " - f"but got composed layout instead: {input.layout}" - f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", - ) - return True - - -class MmaTrait(Trait): - admissible_fields = [Field.ACCUMULATE] - - def set(self, field, value, *, loc=None, ip=None) -> None: - if field not in self.admissible_fields: - raise ValueError( - f"invalid field, must be {Field.ACCUMULATE}, but got {field}" - ) - field_name = f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>" - attr = ir.Attribute.parse(field_name) - self.value = _cute_nvgpu_ir.atom_set_value( - self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - - -@dataclass(frozen=True) -class MmaF16BF16Op(MmaOp): - """ - F16/BF16 warpgroup MMA Operation. - - See the `PTX documentation `__. - This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. - """ - - descriptive_name = "warpgroup F16/BF16 MMA Operation" - - def __init__( - self, - ab_dtype: Type[Numeric], - acc_dtype: Type[Numeric], - instruction_shape: Shape, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - super().__init__( - ab_dtype, - ab_dtype, - acc_dtype, - instruction_shape, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self) -> None: - # Input data type verification - if self.a_dtype not in [Float16, BFloat16]: - raise OpError( - self, - "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", - ) - assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" - # Accumulator data type verification - if self.acc_dtype not in [Float16, Float32]: - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", - ) - if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32): - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", - ) - # Verify the instruction shape - instruction_k = 16 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( - shape_mnk.type.attribute, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.a_src._to_ir(), - ) - return MmaF16BF16Trait( - _cute_nvgpu_ir.make_sm90_mma( - ty, - Boolean(False).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - ) - - -class MmaF16BF16Trait(MmaTrait): - pass - - -@dataclass(frozen=True) -class MmaF8Op(MmaOp): - """ - F16/BF16 warpgroup MMA Operation. - - See the `PTX documentation `__. - This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands. - """ - - descriptive_name = "warpgroup F8 MMA Operation" - - def __init__( - self, - a_dtype: Type[Numeric], - b_dtype: Type[Numeric], - acc_dtype: Type[Numeric], - instruction_shape: Shape, - a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, - ) -> None: - super().__init__( - a_dtype, - b_dtype, - acc_dtype, - instruction_shape, - a_src, - a_major_mode, - b_major_mode, - ) - self._verify() - - def _verify(self): - # Input data type verification - if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: - raise OpError( - self, - "expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", - ) - if self.b_dtype not in [Float8E5M2, Float8E4M3FN]: - raise OpError( - self, - "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", - ) - # Accumulator data type verification - if self.acc_dtype not in [Float16, Float32]: - raise OpError( - self, - "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", - ) - # Verify the instruction shape - instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: - raise OpError( - self, - f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", - ) - - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait": - shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) - ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( - shape_mnk.type.attribute, - self.a_major_mode._to_ir(), - self.b_major_mode._to_ir(), - self.a_dtype.mlir_type, - self.b_dtype.mlir_type, - self.acc_dtype.mlir_type, - self.a_src._to_ir(), - ) - return MmaF8Trait( - _cute_nvgpu_ir.make_sm90_mma( - ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip - ) - ) - - -class MmaF8Trait(MmaTrait): - pass - - -#################################################################################################### -# -# SMEM layout atoms -# -#################################################################################################### - - -class SmemLayoutAtomKind(enum.Enum): - """ - Enum class for the kinds of SMEM layout atoms for SM90. - - Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can - be used to construct an SMEM layout using blocked product for operand A or B such that the - resulting layout is legal for both TMA and UMMA. - - Note that there are other ways of creating legal layouts for operand A and B. - """ - - MN_INTER = enum.auto() - MN_SW32 = enum.auto() - MN_SW64 = enum.auto() - MN_SW128 = enum.auto() - K_INTER = enum.auto() - K_SW32 = enum.auto() - K_SW64 = enum.auto() - K_SW128 = enum.auto() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py deleted file mode 100644 index 9128c67a24a7202713c354fb99b2891542f0c887..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py +++ /dev/null @@ -1,510 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import ctypes -from functools import lru_cache -import itertools -import operator -from time import time -from typing import Union - -# MLIR modules imports -from cutlass._mlir import ir -import cutlass._mlir.dialects.cute as _cute_ir - -from cutlass.base_dsl.dsl import is_dynamic_expression -from cutlass.cutlass_dsl import JitArgAdapterRegistry - -# Local modules imports -from .typing import ( - AddressSpace, - Tensor, - Type, - Pointer, - Boolean, - Numeric, - Float4E2M1FN, - Int64, - Int32, - Int16, - Int8, - Uint64, - Uint32, - Uint16, - Uint8, - Float64, - Float32, - Float16, - BFloat16, - Float8E5M2, -) -from . import core -from .core import _Tensor as CoreTensor - - -class _Pointer(Pointer): - """Runtime representation of a pointer that can inter-operate with various data structures, - including numpy arrays and device memory. - - :param pointer: The pointer to the data - :type pointer: int or pointer-like object - :param dtype: Data type of the elements pointed to - :type dtype: Type - :param mem_space: Memory space where the pointer resides, defaults to generic - :type mem_space: _cute_ir.AddressSpace, optional - :param assumed_align: Assumed alignment of input pointer in bytes, defaults to None - :type assumed_align: int, optional - - :ivar _pointer: The underlying pointer - :ivar _dtype: Data type of the elements - :ivar _addr_space: Memory space of the pointer - :ivar _assumed_align: Alignment of the pointer in bytes - :ivar _desc: C-type descriptor for the pointer - :ivar _c_pointer: C-compatible pointer representation - """ - - def __init__( - self, - pointer, - dtype, - mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic, - assumed_align=None, - ): - self._pointer = pointer - self._dtype = dtype - self._addr_space = mem_space - - if assumed_align is None: - self._assumed_align = dtype.width // 8 - else: - self._assumed_align = assumed_align - - self._c_pointer = None - assert ( - int(self._pointer) % self._assumed_align == 0 - ), f"pointer must be {self._assumed_align} bytes aligned" - - def size_in_bytes(self) -> int: - self._desc = ctypes.c_void_p(int(self._pointer)) - return ctypes.sizeof(self._desc) - - def __get_mlir_types__(self): - return [self.mlir_type] - - def __c_pointers__(self): - if self._c_pointer is None: - self._desc = ctypes.c_void_p(int(self._pointer)) - self._c_pointer = ctypes.addressof(self._desc) - return [self._c_pointer] - - def __new_from_mlir_values__(self, values): - assert len(values) == 1 - return values[0] - - def __extract_mlir_values__(self): - return [self._c_pointer] - - # Move mlir Type out of __init__ to decouple with mlir Context - @property - def mlir_type(self) -> ir.Type: - return _cute_ir.PtrType.get( - self._dtype.mlir_type, self._addr_space, self._assumed_align - ) - - @property - def dtype(self) -> Type[Numeric]: - return self._dtype - - @property - def memspace(self): - return self._addr_space - - def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: - raise NotImplementedError("align is not supported in runtime") - - def verify(self, expected_py_type): - if expected_py_type is Pointer: - return True - elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer: - return True - - return False - - def __str__(self) -> str: - return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>" - - def __repr__(self): - return self.__str__() - - -class _Tensor(Tensor): - def __init__( - self, - tensor, - assumed_align=None, - ): - # If tensor is already a DLPack object, use it directly - if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"): - self._dlpack_data = tensor - else: - self._dlpack_data = tensor.__dlpack__() - self._dltensor_wrapper = None - self._assumed_align = assumed_align - self._is_dynamic = False - self._memref_desc = None - self._dtype = None - - @property - def __class__(self) -> Type[Tensor]: - # Cheat to let `type(_Tensor())` to return cute.Tensor - return Tensor - - @staticmethod - def lazily_load_dltensor(func): - """Decorator to lazily load the DLTensorWrapper. - - This decorator loads the DLTensorWrapper when needed, - avoiding overhead in the critical path of calling JIT functions. - """ - - def wrapper(self, *args, **kwargs): - if self._dltensor_wrapper is None: - self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data) - return func(self, *args, **kwargs) - - return wrapper - - @lazily_load_dltensor - def mark_layout_dynamic(self, leading_dim: int | None = None): - """Marks the tensor layout as dynamic based on the leading dimension. - - :param leading_dim: The leading dimension of the layout, defaults to None - :type leading_dim: int, optional - - When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout. - The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error - if the layout cannot be automatically deduced. - - When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the - stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent - with the existing layout by checking that the corresponding stride of that dimension is 1. - - Limitation: only support flat layout for now. Will work on supporting nested layout in the future. - - :return: The tensor with dynamic layout - :rtype: _Tensor - """ - self._dltensor_wrapper.mark_layout_dynamic(leading_dim) - return self - - @lazily_load_dltensor - def mark_compact_shape_dynamic( - self, - mode: int, - stride_order: tuple[int, ...] | None = None, - divisibility: int = 1, - ): - """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides. - - :param mode: The mode of the compact shape, defaults to 0 - :type mode: int - :param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None. - Indicates the order of the modes (dimensions) if the current layout were converted to row-major order. - It starts from the outermost to the innermost dimension. - :type stride_order: tuple[int, ...], optional - :param divisibility: The divisibility constraint for the compact shape, defaults to 1 - :type divisibility: int, optional - :return: The tensor with dynamic compact shape - :rtype: _Tensor - - If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout. - Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout). - An error is raised if automatic deduction fails. - - If ``stride_order`` is explicitly specified, it does the consistency check with the layout. - - For example: - - Layout: (4,2):(1,4) has stride_order: (1,0) indicates the innermost dimension is 0(`4:1`), the outermost dimension is 1(`2:4`) - - Layout: (5,3,2,4):(3,1,15,30) has stride_order: (3,2,0,1) indicates the innermost dimension is 1(`3:1`), the outermost dimension is 3(`4:30`). - - Using `torch.Tensor.dim_order()` to get the stride order of the torch tensor. - .. code-block:: python - a = torch.empty(3, 4) - t = cute.runtime.from_dlpack(a) - t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order()) - """ - self._dltensor_wrapper.mark_compact_shape_dynamic( - mode, stride_order, divisibility - ) - return self - - @property - @lazily_load_dltensor - def element_type(self) -> Type[Numeric]: - if self._dtype is None: - self._dtype = self._dltensor_wrapper.dtype - return self._dtype - - @element_type.setter - def element_type(self, new_type): - """Set the element type of the tensor. - - :warning: This API is added for narrow precision before we have a clean `recast_tensor` story. - - :note: It is only used for the case that frameworks don't natively support narrow precision but we get tensor - from frameworks with storage type like uint8. - - **Example**: - - .. code-block:: python - - # Create a tensor from a numpy array - import numpy as np - from cutlass.cute import from_dlpack - - # Create a tensor with Float32 elements - a = np.zeros(shape, dtype=np.uint8) - tensor = from_dlpack(a) - - # Change the element type to Float4E2M1FN even storage type is uint8 - tensor.element_type = cutlass.Float4E2M1FN - - src = from_dlpack(... data tensor ...) - # convert and initialize narrow precision tensor - cute.testing.convert(src, tensor) - """ - self._dtype = new_type - - @property - @lazily_load_dltensor - def memspace(self): - return self._dltensor_wrapper.address_space - - @property - @lazily_load_dltensor - def size_in_bytes(self) -> int: - return self._dltensor_wrapper.size_in_bytes() - - @property - @lazily_load_dltensor - def mlir_type(self) -> ir.Type: - return self._dltensor_wrapper.get_type( - self.element_type.mlir_type, self._assumed_align - ) - - @lazily_load_dltensor - def __str__(self) -> str: - return f"Tensor<0x{self._dltensor_wrapper.str}>" - - def __repr__(self): - return self.__str__() - - def __setitem__(self, crd, value): - raise TypeError(f"runtime._Tensor is not indexable") - - def __getitem__(self, crd): - raise TypeError(f"runtime._Tensor is not indexable") - - @property - @lazily_load_dltensor - def iterator(self): - return _Pointer( - self._dltensor_wrapper.data_ptr, - self.element_type, - self.memspace, - self._assumed_align, - ) - - @property - def layout(self): - raise NotImplementedError( - f"layout property is not supported in runtime, support in future" - ) - - @property - @lazily_load_dltensor - def shape(self): - return self._dltensor_wrapper.shape - - @property - @lazily_load_dltensor - def stride(self): - strides = self._dltensor_wrapper.stride - if strides is None: - strides = itertools.accumulate( - reversed(self.shape), func=operator.mul, initial=1 - ) - strides = tuple(reversed(list(strides)[:-1])) - - return strides - - @property - @lru_cache(maxsize=128, typed=True) - def leading_dim(self): - """Get the leading dimension of this Tensor. - - :return: The leading dimension index or indices - :rtype: int or tuple or None - - The return value depends on the tensor's stride pattern: - - * If a single leading dimension is found, returns an integer index - * If nested leading dimensions are found, returns a tuple of indices - * If no leading dimension is found, returns None - """ - return core.leading_dim(self.shape, self.stride) - - def fill(self, value: Numeric): - raise TypeError(f"fill function is not supported in runtime") - - @property - @lazily_load_dltensor - def data_ptr(self): - return self._dltensor_wrapper.data_ptr - - @lazily_load_dltensor - def __c_pointers__(self): - self._memref_desc = self._dltensor_wrapper.build_memref_desc( - self._assumed_align - ) - return [_cute_ir.pycapsule_get_pointer(self._memref_desc)] - - def __get_mlir_types__(self): - return [self.mlir_type] - - def __new_from_mlir_values__(self, values): - assert len(values) == 1 - assert isinstance(values[0], CoreTensor) - return CoreTensor(values[0].value, self._dtype) - - -def from_dlpack( - tensor_dlpack, - assumed_align=None, -) -> Tensor: - """Convert from tensor object supporting __dlpack__() to a CuTe Tensor. - - :param tensor_dlpack: Tensor object that supports the DLPack protocol - :type tensor_dlpack: object - :param assumed_align: Assumed alignment of the tensor (bytes), defaults to None, - if None, will use the element size bytes as the assumed alignment. - :type assumed_align: int, optional - :return: A CuTe Tensor object - :rtype: Tensor - - Examples: - .. code-block:: python - - import torch - from cutlass.cute.runtime import from_dlpack - x = torch.randn(100, 100) - y = from_dlpack(x) - y.shape - # (100, 100) - type(y) - # - """ - return _Tensor( - tensor_dlpack, - assumed_align=assumed_align, - ) - - -def make_ptr( - dtype: Type[Numeric], - value: Union[int, ctypes._Pointer], - mem_space: AddressSpace = AddressSpace.generic, - assumed_align=None, -) -> Pointer: - """Create a pointer from a memory address - - :param dtype: Data type of the pointer elements - :type dtype: Type[Numeric] - :param value: Memory address as integer or ctypes pointer - :type value: Union[int, ctypes._Pointer] - :param mem_space: Memory address space, defaults to AddressSpace.generic - :type mem_space: AddressSpace, optional - :param align_bytes: Alignment in bytes, defaults to None - :type align_bytes: int, optional - :return: A pointer object - :rtype: Pointer - - .. code-block:: python - - import numpy as np - import ctypes - - from cutlass import Float32 - from cutlass.cute.runtime import make_ptr - - # Create a numpy array - a = np.random.randn(16, 32).astype(np.float32) - - # Get pointer address as integer - ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) - - # Create pointer from address - y = make_ptr(cutlass.Float32, ptr_address) - - # Check properties - print(y.element_type) - print(type(y)) # - """ - # check if value is int or ctypes.POINTER - if isinstance(value, int): - address_value = value - elif isinstance(value, ctypes._Pointer): - # get address value - address_value = ctypes.cast(value, ctypes.c_void_p).value - assert address_value is not None, "Pointer address is None" - else: - raise TypeError( - f"Expect int or ctypes.POINTER for value but got {type(value)=}" - ) - - return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align) - - -class TensorAdapter: - """ - Convert a DLPack protocol supported tensor/array to a cute tensor. - """ - - def __init__(self, arg): - self._arg = from_dlpack(arg).mark_layout_dynamic() - - def __new_from_mlir_values__(self, values): - return self._arg.__new_from_mlir_values__(values) - - def __c_pointers__(self): - return self._arg.__c_pointers__() - - def __get_mlir_types__(self): - return self._arg.__get_mlir_types__() - - -# ------------------------------------------------------------------------- -# Try to register_jit_arg_adapter for TensorAdapter -# ------------------------------------------------------------------------- - -try: # Register for numpy.ndarray - import numpy - - JitArgAdapterRegistry.register_jit_arg_adapter(numpy.ndarray)(TensorAdapter) -except ImportError: - pass # silent attempt, suppress error - -try: # Register for torch.Tensor - import torch - - JitArgAdapterRegistry.register_jit_arg_adapter(torch.Tensor)(TensorAdapter) -except ImportError: - pass # silent attempt, suppress error diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py deleted file mode 100644 index 88e0da048fc951da5091bcc38a6e6c92164f6d04..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py +++ /dev/null @@ -1,610 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import functools -import inspect -import logging -import os -from enum import Enum -from inspect import isclass -from itertools import product -from time import time -from typing import Any, Callable, Dict, List, Optional, Type, Union - -import cuda.bindings.driver as cuda_driver -import cuda.bindings.runtime as cuda_runtime -import numpy as np - -import cutlass._mlir.ir as ir -import cutlass.base_dsl.jit_executor -import cutlass.cute as cute -from cutlass._mlir.dialects import builtin, cf, nvvm, vector -from cutlass.cute import core, nvgpu -from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op - - -@dsl_user_op -def assert_(cond, msg=None, *, loc=None, ip=None): - cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) - - -def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout): - if src.element_type.width == 4: - tv_layout = core.recast_layout(8, 4, tv_layout) - src = core.recast_tensor(src, dtype=t.Int8) - return src, tv_layout - - -def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]): - """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit. - - :param input: The input tensor to recast. - :param dtype: The target numeric type to potentially recast to. - :raises TypeError: If dtype is not a subclass of Numeric. - :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged. - """ - if not isclass(dtype) or not issubclass(dtype, core.Numeric): - raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}") - - if dtype.width == 4: - recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape - i4_vec = vector.bitcast( - T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast() - ) - res_vect = builtin.unrealized_conversion_cast( - [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec] - ) - return core.TensorSSA(res_vect, recast_shape, dtype) - return input - - -def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]): - """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit. - - :param input: The input tensor to recast. - :param src_dtype: The source numeric type to potentially recast from. - :raises TypeError: If src_dtype is not a subclass of Numeric. - :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged. - """ - if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric): - raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}") - - if src_dtype.width == 4: - recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape - i4_vec = builtin.unrealized_conversion_cast( - [T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()] - ) - res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec) - return core.TensorSSA(res_vect, recast_shape, core.Int8) - return input - - -@CuTeDSL.kernel -def _convert_kernel( - gSrc: core.Tensor, - gDst: core.Tensor, - cSrc: core.Tensor, - src_tv_layout: core.Layout, - dst_tv_layout: core.Layout, - src_shape: core.Shape, - src_ty, - dst_ty, -): - tidx = nvvm.read_ptx_sreg_tid_x(T.i32()) - bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32()) - - cta_coord = (None, bidx) - # logical idx -> address - ctaSrc = gSrc[cta_coord] # (...,TileV,...) - ctaDst = gDst[cta_coord] # (...,TileV,...) - ctaCSrc = cSrc[cta_coord] # (...,TileV,...) - # print(f"ctaSrc = {ctaSrc.type}") - - # compose with CTA TV layout - # tid, vid -> address - tidfrgSrc = core.composition(ctaSrc, src_tv_layout) # (T,V) - tidfrgDst = core.composition(ctaDst, dst_tv_layout) # (T,V) - tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout) # (T,V) - # print(f"tidfrgSrc = {tidfrgSrc.type}") - - # slice for threads - thr_coord = (tidx, None) - thrSrc = tidfrgSrc[thr_coord] # (V) - thrDst = tidfrgDst[thr_coord] # (V) - thrCSrc = tidfrgCSrc[thr_coord] # (V) - # print(f"thrSrc = {thrSrc.type}") - - # predicate - if core.elem_less(thrCSrc[0], src_shape): - # allocate fragments for gmem->rmem - frgSrc = core.make_fragment( - core.get(src_tv_layout, mode=[1]), gSrc.element_type - ) # (V) - frgDst = core.make_fragment( - core.get(dst_tv_layout, mode=[1]), gDst.element_type - ) # (V) - # print(f"frgSrc = {frgSrc.type}") - - # Move data to reg address space - copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type) - core.copy(copy_atom_load, thrSrc, frgSrc) - - vec_src = frgSrc.load() - vec_src = _maybe_recast_to_f4(vec_src, src_ty) - vec_dst = vec_src.to(dst_ty) - vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty) - frgDst.store(vec_dst) - - # Copy the results back to c - copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type) - core.copy(copy_atom_stg, frgDst, thrDst) - - -@CuTeDSL.jit(preprocess=False) -def _convert( - src: core.Tensor, - dst: core.Tensor, - leading_mode: Constexpr, - elem_per_copy: Constexpr, -): - - # Step 1. figure proper tv_layout - src_ty = src.element_type - dst_ty = dst.element_type - - tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1)) - - # Step 2. maybe recast from f4 tensor - src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout) - dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout) - src_shape = src.shape - # predicate tensor - idA = core.make_identity_tensor(src.shape) - - # Step 3. select a proper tiling pattern as (...,TileV, ...) - src_cta_tiler = [ - 1, - ] * core.rank(src.layout) - src_cta_tiler[leading_mode] = core.size(src_tv_layout) # (...,TileV,...) - dst_cta_tiler = [ - 1, - ] * core.rank(dst.layout) - dst_cta_tiler[leading_mode] = core.size(dst_tv_layout) # (...,TileV,...) - - # Step 4. partition input and output tensor by cta tiler. - gS = core.zipped_divide( - src, tuple(src_cta_tiler) - ) # ((...,TileV,...),(...,RestV,...)) - cS = core.zipped_divide( - idA, tuple(src_cta_tiler) - ) # ((...,TileV,...),(...,RestV,...)) - gD = core.zipped_divide( - dst, tuple(dst_cta_tiler) - ) # ((...,TileV,...),(...,RestV,...)) - # print(f"{gS.type=}") - - _convert_kernel( - gS, - gD, - cS, - src_tv_layout, - dst_tv_layout, - src_shape, - src_ty, - dst_ty, - ).launch( - grid=[core.size(gS, mode=[1]), 1, 1], - block=[core.size(src_tv_layout, mode=[0]), 1, 1], - ) - - -# Converts from src tensor to dst tensor, their logical shape are required to be the same. -# And when src or dst dtype is narrow precision(Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the shape of -# their leading dimension should be 4(fp8)/8(fp4) element align. (nvgpu.cvt_fptrunc/cvt_fpext -# needs 32-bits aligned input/output) -def convert(src: core.Tensor, dst: core.Tensor): - assert len(src.shape) == len( - dst.shape - ), "Shape of src and dst tensors should be the same rank." - # find leading mode - leading_mode = [ - idx - for idx, (shape, stride) in enumerate(zip(src.shape, src.stride)) - if shape > 1 and stride == 1 - ] - if len(leading_mode) != 1: - raise ValueError(f"Leading mode should be unique, but got {leading_mode}") - leading_mode = leading_mode[0] - - elem_per_copy = 2 - - if src.element_type.width == 4 or dst.element_type.width == 4: - elem_per_copy = 8 - elif src.element_type.width == 8 or dst.element_type.width == 8: - elem_per_copy = 4 - assert ( - src.shape[leading_mode] % elem_per_copy == 0 - and dst.shape[leading_mode] % elem_per_copy == 0 - ) - _convert(src, dst, leading_mode, elem_per_copy) - - -######################################### -# Testing utilities -######################################### - - -def sample_pytest(rand_cfg=None): - """ - Decorator to randomly sample pytest parametrized tests. - rand_cfg: Tuple[int, float] - (random_seed, sample_ratio) - Sampling is disabled when: - - A specific test is selected (via -k or direct test path) - - Not running under pytest - """ - import functools - import os - import random - import sys - - import pytest - - seed, sample_ratio = rand_cfg - random.seed(seed) - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ: - # Check if test was explicitly selected like ::test_name[param1-param2-...] - if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv): - # Test was explicitly selected, don't skip - return func(*args, **kwargs) - - if random.uniform(0.0, 1.0) > sample_ratio: - pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})") - return func(*args, **kwargs) - - return wrapper - - return decorator - - -######################################### -# Benchmarking utilities -######################################### - - -class JitArguments: - """ - A type to hold both args and kwargs for passing to a kernel while benchmarking. - """ - - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - -def _cuda_success( - err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str -): - """ - Helper function to check CUDA API errors. - """ - if isinstance(err, tuple): - _cuda_success(err[0], message) - elif isinstance(err, cuda_runtime.cudaError_t): - error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8") - if err != cuda_runtime.cudaError_t.cudaSuccess: - raise RuntimeError(f"{message} : {error_message}") - elif isinstance(err, cuda_driver.CUresult): - if err != cuda_driver.CUresult.CUDA_SUCCESS: - error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8") - raise RuntimeError(f"{message} : {error_message}") - else: - raise TypeError( - f"{err} is an unexpected type : it should be a cudaError_t or CUresult" - ) - - -def _does_kernel_use_stream( - kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs -): - """ - This function checks if the kernel uses the provided non-default stream. - It does this by capturing the stream and then checking if any kernels were launched. - :param kernel: The kernel to check - :type kernel: Callable - :param stream: The stream to check - :type stream: cuda_driver.CUstream - :return: True if the kernel uses the stream, False otherwise - :rtype: bool - """ - - assert int(stream) != int( - cuda_driver.CUstream_flags.CU_STREAM_DEFAULT - ), "Stream must be a non-default stream" - - err = cuda_runtime.cudaStreamBeginCapture( - stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal - ) - _cuda_success(err, "Error on stream capture") - - kernel(*args, **kwargs) - - err, graph = cuda_runtime.cudaStreamEndCapture(stream) - _cuda_success(err, "Error on stream capture") - - # Get number of nodes in warmup graph to check it matches what is expected - err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph) - _cuda_success(err, "Error on querying graph") - return num_nodes > 0 - - -def benchmark( - callable: Callable, - *, - warmup_iterations: int = 10, - iterations: int = 100, - stream: Optional[cuda_driver.CUstream] = None, - kernel_arguments: Optional[JitArguments] = None, - workspace_generator: Optional[Callable[[], JitArguments]] = None, - workspace_count: int = 1, - use_cuda_graphs: bool = False, -) -> float: - """Benchmarks a callable function with the specified parameters. - - For example, - .. code-block:: python - - from cutlass.cute.testing import benchmark - - @cute.jit - def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream): - # contents of the function - pass - - time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream) - warmup_iterations=10, iterations=100 - stream=stream) - - To prevent skewing results by repeately accessing the L2 cache, use the workspace_count and workspace_generator - parameters to cycle through a number of different workspaces. - - .. code-block:: python - - from cutlass.cute.testing import benchmark - - @cute.jit - def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): - # contents of the function - pass - - def workspace_generator(): - # create a, b, and c - return JitArguments(a, b, c) - - time_us = benchmark(user_function, - workspace_generator=workspace_generator, - workspace_count=10, - warmup_iterations=10000, - iterations=1000) - - To benchmark you may always configure the function being profiled (callable), the warmup iterations, and - the number of profiling iterations. - - Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter. - - To use CUDA graphs, the callable must be a compiled @cute.jit annotated function. - When using CUDA graphs, the kernel must be launched in a non-default stream. - - :param callable: The function to benchmark - :type callable: Callable - :param warmup_iterations: Number of warmup iterations, defaults to 10 - :type warmup_iterations: int, optional - :param iterations: Number of benchmark iterations, defaults to 100 - :type iterations: int, optional - :param stream: Stream kernel is launched in, defaults to CUDA stream default - :type stream: CUstream, None - :param kernel_arguments: Kernel arguments to launch callable with, defaults to None - :type kernel_arguments: JitArguments, None - :param workspace_generator: Function that returns kernel arguments, defaults to None - :type workspace_generator: Callable - :param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold - :type workspace_count: int, optional - :param use_cuda_graphs: Whether to use cuda graphs, defaults to False - :type use_cuda_graphs: bool, optional - - :return: The benchmark time in microseconds - :rtype: float - """ - - if stream is None: - stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) - - if workspace_count < 1: - raise ValueError("workspace_count must be at least 1") - - time_us = float("nan") - if workspace_generator == None: - # If no workspace generator is provided, we need a single workspace - if workspace_count != 1: - raise ValueError("Need a single workspace if not providing a generator") - - # If no workspace generator is provided, we need a kernel_argument - if kernel_arguments == None: - raise ValueError( - "Please pass a kernel argument if not providing a generator" - ) - workspace_generator = lambda: kernel_arguments - - workspaces = [workspace_generator() for _ in range(workspace_count)] - - for workspace in workspaces: - if type(workspace) != JitArguments: - raise TypeError( - "workspace_generator and/or kernel_arguments should use JitArguments type" - ) - - def _loop_and_call_kernel(iterations: int, workspace_index: int = 0): - for _ in range(iterations): - current_workspace = workspaces[workspace_index] - callable(*current_workspace.args, **current_workspace.kwargs) - workspace_index = (workspace_index + 1) % workspace_count - return workspace_index - - # Create CUDA events for timing - err, start_event = cuda_driver.cuEventCreate( - cuda_driver.CUevent_flags.CU_EVENT_DEFAULT - ) - _cuda_success(err, "Error on creating event") - err, end_event = cuda_driver.cuEventCreate( - cuda_driver.CUevent_flags.CU_EVENT_DEFAULT - ) - _cuda_success(err, "Error on creating event") - - elapsed_time = float("nan") - - if use_cuda_graphs: - # Check if the callable is a JitExecutor - if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor): - raise TypeError("Function must be precompiled to be used with CUDA Graphs") - - # Check if the stream is a non-default stream - if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT): - raise ValueError( - "Measuring with CUDA Graphs requires executing in a non-default stream" - ) - - workspace_index = 0 - - # Capture warmup graph - err = cuda_runtime.cudaStreamBeginCapture( - stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal - ) - _cuda_success(err, "Error on stream capture") - - workspace_index = _loop_and_call_kernel(warmup_iterations) - err, gwarm = cuda_runtime.cudaStreamEndCapture(stream) - _cuda_success(err, "Error on stream capture") - - # Get number of nodes in warmup graph to check it matches what is expected - err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm) - _cuda_success(err, "Error on querying graph") - # Assertion is >= since we may launch multiple kernels in one host function - if num_nodes < warmup_iterations: - raise ValueError( - f"CUDA stream passed to benchmark does not match the stream the kernel was launched in" - ) - - # Capture profiling graph - err = cuda_runtime.cudaStreamBeginCapture( - stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal - ) - _cuda_success(err, "Error on stream capture") - _loop_and_call_kernel(iterations, workspace_index) - err, gprofile = cuda_runtime.cudaStreamEndCapture(stream) - _cuda_success(err, "Error on stream capture") - - # Instantiate graphs - err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0) - _cuda_success(err, "Error on graph instantiation") - err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0) - _cuda_success(err, "Error on graph instantiation") - - # Launch warmup graph - err = cuda_runtime.cudaGraphLaunch(gwarm, stream) - _cuda_success(err, "Error on graph launch") - - # Record start time - err = cuda_driver.cuEventRecord(start_event, stream) - _cuda_success(err, "Error on recording event") - - # Launch profiling graph - err = cuda_runtime.cudaGraphLaunch(gprofile, stream) - _cuda_success(err, "Error on graph launch") - - # Record end time - err = cuda_driver.cuEventRecord(end_event, stream) - _cuda_success(err, "Error on recording event") - err = cuda_driver.cuEventSynchronize(end_event) - _cuda_success(err, "Error on synchronizing event") - - # Get elapsed time - err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) - _cuda_success(err, "Error on querying event") - - # Destroy graphs - err = cuda_runtime.cudaGraphExecDestroy(gwarm) - _cuda_success(err, "Error on destroying graph") - err = cuda_runtime.cudaGraphExecDestroy(gprofile) - _cuda_success(err, "Error on destroying graph") - - else: - - if int(stream) != int( - cuda_driver.CUstream_flags.CU_STREAM_DEFAULT - ) and not _does_kernel_use_stream( - callable, stream, *workspaces[0].args, **workspaces[0].kwargs - ): - raise ValueError( - "CUDA stream passed to benchmark does not match the stream the kernel was launched in" - ) - - # Not using graphs - # Warmup - workspace_index = _loop_and_call_kernel(warmup_iterations) - # Record start event - err = cuda_driver.cuEventRecord(start_event, stream) - _cuda_success(err, "Error on recording event") - _loop_and_call_kernel(iterations, workspace_index) - # Record end event - err = cuda_driver.cuEventRecord(end_event, stream) - _cuda_success(err, "Error on recording event") - # Synchronize end event - err = cuda_driver.cuEventSynchronize(end_event) - _cuda_success(err, "Error on synchronizing event") - err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) - _cuda_success(err, "Error on querying event") - - # Destroy events - err = cuda_driver.cuEventDestroy(start_event) - _cuda_success(err, "Error on destroying event") - err = cuda_driver.cuEventDestroy(end_event) - _cuda_success(err, "Error on destroying event") - - return elapsed_time / iterations * 1e3 - - -def get_workspace_count( - one_workspace_bytes: int, warmup_iterations: int, iterations: int -) -> int: - """Calculate the number of workspaces needed to fill L2 cache. - - :param one_workspace_bytes: Size of one workspace in bytes - :type one_workspace_bytes: int - :param warmup_iterations: Number of warmup iterations - :type warmup_iterations: int - :param iterations: Number of iterations - :type iterations: int - :return: Number of workspaces needed - :rtype: int - """ - num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes() - return max( - 1, - min( - warmup_iterations + iterations, # Don't create more workspaces than needed - (num_l2_cache_bytes + one_workspace_bytes - 1) - // one_workspace_bytes, # Ceiling division - ), - ) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py deleted file mode 100644 index 215e71d98fc39c192c784c99bb8ef14f6e2f55d9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py +++ /dev/null @@ -1,207 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from abc import ABC, abstractmethod -from typing import ForwardRef, Tuple, Union, Any, Type, List - -from cutlass.base_dsl.typing import * - -from cutlass._mlir import ir -import cutlass._mlir.extras.types as T -from cutlass._mlir.dialects.cute import AddressSpace - - -Int = Union[int, Integer] - - -ScaledBasis = ForwardRef("ScaledBasis") - - -IntTuple = Union[Int, Tuple["IntTuple", ...]] -Shape = Union[Int, Tuple["Shape", ...]] -Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]] -Coord = Union[Int, None, Tuple["Coord", ...]] - - -class Layout(ir.Value): - def __init__(self, op_result): - super().__init__(op_result) - - def __str__(self): ... - - def get_hier_coord(self, idx) -> Coord: - """Return the (hierarchical) ND logical coordinate corresponding to the linear index""" - ... - - @property - def shape(self, *, loc=None, ip=None) -> Shape: ... - - @property - def stride(self, *, loc=None, ip=None) -> Stride: ... - - -Tile = Union[Int, None, Layout, Tuple["Tile", ...]] - -# XTuple is super set of above types -XTuple = Union[IntTuple, Shape, Stride, Coord, Tile] - -Tiler = Union[Shape, Layout, Tile] - - -class Pointer(ABC): - """ - Abstract base class for CuTe jit function and runtime _Pointer - """ - - @property - def value_type(self) -> Type[Numeric]: - return self.dtype - - @property - def dtype(self) -> Type[Numeric]: ... - - def align(self, min_align: int) -> "Pointer": ... - - def __get_mlir_types__(self) -> List[ir.Type]: ... - - def __extract_mlir_values__(self) -> List[ir.Value]: ... - - def __new_from_mlir_values__(self, values) -> "Pointer": ... - - -class Tensor(ABC): - """ - Abstract base class for CuTe jit function and runtime _Tensor - - A CuTe Tensor is iterator with layout - - :Examples: - - Create tensor from torch.tensor with Host Runtime: - - .. code-block:: python - - >>> import torch - >>> from cutlass.cute.runtime import from_dlpack - >>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32)) - >>> mA.shape - (3,) - >>> mA.stride - (1,) - >>> mA.layout - (3,):(1,) - - Define JIT function: - - .. code-block:: python - - @cute.jit - def add(a: Tensor, b: Tensor, res: Tensor): ... - - Call JIT function from python: - - .. code-block:: python - - >>> import torch - >>> a = torch.tensor([1, 3, 5], dtype=torch.int32) - >>> b = torch.tensor([2, 4, 6], dtype=torch.int32) - >>> c = torch.zeros([3], dtype=torch.int32) - >>> mA = from_dlpack(a) - >>> mB = from_dlpack(b) - >>> mC = from_dlpack(c) - >>> add(mA, mB, mC) - >>> c - tensor([3, 7, 11], dtype=torch.int32) - """ - - def __str__(self): ... - - @abstractmethod - def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ... - - @abstractmethod - def __setitem__(self, idx, value): ... - - @property - @abstractmethod - def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ... - - @element_type.setter - def element_type(self, new_type): ... - - @property - @abstractmethod - def memspace(self) -> AddressSpace: ... - - @property - @abstractmethod - def iterator(self): ... - - @property - def layout(self) -> Union[Layout, "ComposedLayout"]: ... - - @property - def shape(self) -> Shape: ... - - def load(self, *, loc=None, ip=None) -> "TensorSSA": ... - - def store(self, data: "TensorSSA", *, loc=None, ip=None): ... - - def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ... - - def mark_compact_shape_dynamic( - self, - mode: int, - stride_order: tuple[int, ...] | None = None, - divisibility: int = 1, - ) -> "Tensor": ... - - @abstractmethod - def fill(self, value: Numeric) -> None: ... - - -__all__ = [ - "Coord", - "Numeric", - "Integer", - "Boolean", - "Int8", - "Int16", - "Int32", - "Int64", - "Uint8", - "Uint16", - "Uint32", - "Uint64", - "Float", - "Float16", - "BFloat16", - "TFloat32", - "Float32", - "Float64", - "Float8E5M2", - "Float8E4M3FN", - "Float8E4M3B11FNUZ", - "Float8E4M3", - "Float8E8M0FNU", - "Float4E2M1FN", - "Float6E2M3FN", - "Float6E3M2FN", - "IntTuple", - "Layout", - "Pointer", - "Shape", - "Stride", - "Tensor", - "Tile", - "Tiler", - "XTuple", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py deleted file mode 100644 index 0bb9b5207144a11665449fac431fcbe2bd8f49bd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - - -def check_value_in( - value, possible_values: list, value_description: str, prefix="" -) -> None: - if value not in possible_values: - err_msg = prefix - if err_msg != "": - err_msg += ": " - err_msg += f"invalid {value_description}, got {value}, must be one of {possible_values}" - raise ValueError(err_msg) - - -def check_type_in(ty, possible_types: list, type_description: str, prefix="") -> None: - if not isinstance(ty, type): - ty = type(ty) - if ty not in possible_types: - err_msg = prefix - if err_msg != "": - err_msg += ": " - err_msg += f"invalid type for {type_description}, got {ty}, must be one of {possible_types}" - raise TypeError(err_msg) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py deleted file mode 100644 index 7df24dd6bb6a5e42ebf5bad0e785cf77589bbbc6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .helpers import ( - Agent, - CooperativeGroup, - PipelineOp, - SyncObject, - MbarrierArray, - NamedBarrier, - TmaStoreFence, - PipelineUserType, - PipelineState, - make_pipeline_state, - pipeline_init_wait, - arrive, - arrive_unaligned, - wait, - wait_unaligned, - arrive_and_wait, - sync, -) - -from .sm90 import ( - PipelineAsync, - PipelineCpAsync, - PipelineTmaAsync, - PipelineTmaMultiConsumersAsync, - PipelineTmaStore, - PipelineProducer, - PipelineConsumer, -) - -from .sm100 import ( - PipelineTmaUmma, - PipelineAsyncUmma, - PipelineUmmaAsync, -) - -__all__ = [ - "Agent", - "CooperativeGroup", - "PipelineOp", - "SyncObject", - "MbarrierArray", - "NamedBarrier", - "TmaStoreFence", - "PipelineUserType", - "PipelineState", - "PipelineAsync", - "PipelineCpAsync", - "PipelineTmaAsync", - "PipelineTmaUmma", - "PipelineTmaMultiConsumersAsync", - "PipelineAsyncUmma", - "PipelineUmmaAsync", - "PipelineTmaStore", - "PipelineProducer", - "PipelineConsumer", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py deleted file mode 100644 index b5b94899435224ceda4bd152944e9a4b9bc2e911..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py +++ /dev/null @@ -1,652 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Optional, Union -import warnings - -import cutlass.cute as cute -from cutlass.cutlass_dsl import Boolean, Int32, Int64, if_generate -from cutlass._mlir.dialects import llvm -import cutlass._mlir.dialects.cute as _cute_ir - - -############################################################################## -# Agent class -############################################################################## - - -class Agent(enum.Enum): - """ - Agent indicates what is participating in the pipeline synchronization. - """ - - # Arbitrary grouping of N threads - Thread = enum.auto() - # Same as AsyncThread, but includes all threads in the block - ThreadBlock = enum.auto() - # Same as AsyncThread, but includes all threads in the cluster - ThreadBlockCluster = enum.auto() - - -class CooperativeGroup: - """ - CooperativeGroup contains size and alignment restrictions for an Agent. - """ - - def __init__(self, agent: Agent, size: int = 1, alignment: int = 1): - if agent is Agent.Thread: - assert size > 0 - if size == 32: - assert ( - size == alignment - ), "Error: Alignment does not match number of threads in a warp." - elif size == 128: - assert ( - size == alignment - ), "Error: Alignment does not match number of threads in a warpgroup." - elif agent is Agent.ThreadBlock: - raise NotImplementedError("Error: Not yet supported.") - elif agent is Agent.ThreadBlockCluster: - raise NotImplementedError("Error: Not yet supported.") - else: - # Should never reach this state - size = 0 - - if size <= 0: - raise ValueError( - "Error: The number of threads in a CooperativeGroup must be more than 0." - ) - - # Size indicates how many threads are participating in this CooperativeGroup - self.size = size - # Agent indicates the type of thread group - self.agent = agent - - -class PipelineOp(enum.Enum): - """ - PipelineOp assigns an operation to an agent corresponding to a specific hardware feature. - """ - - # async-threads - AsyncThread = enum.auto() - # Blackwell (SM100a) MMA instruction - TCGen05Mma = enum.auto() - # Tensor Memory Accelerator load - TmaLoad = enum.auto() - # TMA Store consuming smem produced by AsyncThread - TmaStore = enum.auto() - # Composite of multiple PipelineOps - Composite = enum.auto() - # Async load without TMA - AsyncLoad = enum.auto() - - -def _get_pipeline_op(type_str): - return PipelineOp(type_str) - - -############################################################################## -# SyncObject class -############################################################################## - - -class SyncObject(ABC): - """Abstract base class for hardware synchronization primitives. - - This class defines the interface for different types of hardware synchronization - mechanisms including shared memory barriers, named barriers, and fences. - """ - - @abstractmethod - def arrive(self) -> None: - pass - - @abstractmethod - def wait(self) -> None: - pass - - @abstractmethod - def arrive_and_wait(self) -> None: - pass - - @abstractmethod - def arrive_and_drop(self) -> None: - pass - - @abstractmethod - def get_barrier(self) -> Union[cute.Pointer, int, None]: - pass - - @abstractmethod - def max(self) -> Union[int, None]: - pass - - -class MbarrierArray(SyncObject): - """ - MbarrierArray implements an abstraction for an array of smem barriers. - """ - - def __init__( - self, - barrier_storage: cute.Pointer, - num_stages: int, - agent: tuple[PipelineOp, CooperativeGroup], - tx_count: int = 0, - ) -> None: - self.barrier_storage = barrier_storage - self.tx_count = tx_count - self.num_stages = num_stages - self.op_type, self.cg = agent - self.arrive_count = self.cg.size - - if self.num_stages <= 0: - raise ValueError("Error: Mbarrier stage count must be greater than 0.") - if self.arrive_count <= 0: - raise ValueError("Error: Mbarrier arrive count must be greater than 0.") - if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0: - raise ValueError( - "Error: Mbarrier tx count must not be less than 0 for TMA ops." - ) - - # Store mbarrier base pointer - self.mbarrier_base = self.barrier_storage - - # Mbarrier initialization in constructor - self.mbarrier_init() - - def recast_to_new_op_type(self, new_op_type: PipelineOp) -> "MbarrierArray": - """ - Creates a copy of MbarrierArray with a different op_type without re-initializing barriers - """ - # Create new instance without initialization - new_mbarrier_array = object.__new__(MbarrierArray) - - # Copy all attributes directly - new_mbarrier_array.barrier_storage = self.barrier_storage - new_mbarrier_array.op_type = new_op_type - new_mbarrier_array.cg = self.cg - new_mbarrier_array.num_stages = self.num_stages - new_mbarrier_array.tx_count = self.tx_count - new_mbarrier_array.arrive_count = self.arrive_count - new_mbarrier_array.mbarrier_base = self.mbarrier_base - return new_mbarrier_array - - # Mbarrier initialization - def mbarrier_init(self) -> None: - """ - Initializes an array of mbarriers using warp 0. - """ - - def then_body(): - for index in range(self.num_stages): - cute.arch.mbarrier_init(self.get_barrier(index), self.arrive_count) - - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - - if_generate(warp_idx == 0, then_body) - - def arrive( - self, - index: int, - dst: int, - cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, - ) -> None: - """Select the arrive corresponding to this MbarrierArray's PipelineOp. - - :param index: Index of the mbarrier in the array to arrive on - :type index: int - :param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank. - When None, both ``TCGen05Mma`` and ``AsyncThread`` will arrive on their local mbarrier. - - For ``TCGen05Mma``, ``dst`` serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs - in the cluster with rank = 0, 1, and 3). - - For ``AsyncThread``, ``dst`` serves as a destination cta rank (e.g., 3 means threads will arrive on - the mbarrier with rank = 3 in the cluster). - :type dst: int | None - :param cta_group: CTA group for ``TCGen05Mma``, defaults to None for other op types - :type cta_group: ``cute.nvgpu.tcgen05.CtaGroup``, optional - """ - if self.op_type is PipelineOp.AsyncThread: - self.arrive_mbarrier(index, dst) - elif self.op_type is PipelineOp.TCGen05Mma: - assert ( - cta_group is not None - ), "Error: CTA group must be provided for TCGen05Mma." - self.arrive_tcgen05mma(index, dst, cta_group) - elif self.op_type in [PipelineOp.TmaLoad]: - self.arrive_and_expect_tx(index, self.tx_count) - elif self.op_type is PipelineOp.AsyncLoad: - self.arrive_cp_async_mbarrier(index) - else: - assert ( - False - ), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." - - def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None: - if dst_rank is None: - cute.arch.mbarrier_arrive(self.get_barrier(index)) - else: - cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank) - - def arrive_cp_async_mbarrier(self, index: int): - cute.arch.cp_async_mbarrier_arrive_noinc(self.get_barrier(index)) - - def arrive_tcgen05mma( - self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup - ) -> None: - if mask is None: - with cute.arch.elect_one(): - cute.nvgpu.tcgen05.commit(self.get_barrier(index)) - else: - with cute.arch.elect_one(): - cute.nvgpu.tcgen05.commit(self.get_barrier(index), mask, cta_group) - - def arrive_and_expect_tx(self, index: int, tx_count: int) -> None: - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(self.get_barrier(index), tx_count) - - def try_wait(self, index: int, phase: int) -> Boolean: - return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase) - - def wait(self, index: int, phase: int) -> None: - cute.arch.mbarrier_wait(self.get_barrier(index), phase) - - def arrive_and_wait( - self, - index: int, - phase: int, - dst: int, - cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, - ) -> None: - arrive(index, dst, cta_group) - wait(index, phase) - - def arrive_and_drop(self) -> None: - raise NotImplementedError("Error: Not yet supported.") - - def get_barrier(self, index: int) -> cute.Pointer: - return self.mbarrier_base + index - - def max(self) -> int: - # Transaction barriers have a maximum arrive count of 511 (2^9 - 1). - # Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1). - return 511 - - def __extract_mlir_values__(self): - return [self.barrier_storage] - - def __new_from_mlir_values__(self, values): - return MbarrierArray( - values[0], self.num_stages, (self.op_type, self.cg), self.tx_count - ) - - -@dataclass(frozen=True) -class NamedBarrier(SyncObject): - """ - NamedBarrier is an abstraction for named barriers managed by hardware. - There are 16 named barriers available, with barrier_ids 0-15. - - See the `PTX documentation `__. - """ - - barrier_id: int - num_threads: int - - def __post_init__(self) -> None: - if self.barrier_id < 0 or self.barrier_id >= 16: - raise ValueError("Error: NamedBarrier ID must be between 0 and 16.") - if self.barrier_id == 0: - warnings.warn( - "NamedBarrier ID 0 is by other driver APIs (i.e. sync_threads()) and should not be used." - ) - - def arrive(self) -> None: - """ - The aligned flavor of arrive is used when all threads in the CTA will execute the - same instruction. See PTX documentation. - """ - cute.arch.barrier_arrive( - barrier_id=self.barrier_id, number_of_threads=self.num_threads - ) - - def arrive_unaligned(self) -> None: - """ - The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. - """ - llvm.inline_asm( - None, - [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], - "barrier.arrive $0, $1;", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - def wait(self) -> None: - """ - NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. - If synchronizing two warps in a producer/consumer pairing, the arrive count would be - 32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer - or consumer are counted for mbarriers, while all threads participating in the sync - are counted for NamedBarriers. - """ - warnings.warn( - "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." - ) - self.arrive_and_wait() - - def wait_unaligned(self) -> None: - warnings.warn( - "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." - ) - llvm.inline_asm( - None, - [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], - "barrier.sync $0, $1;", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - def arrive_and_wait(self) -> None: - cute.arch.barrier( - barrier_id=self.barrier_id, number_of_threads=self.num_threads - ) - - def arrive_and_drop(self) -> None: - raise NotImplementedError("Error: Not supported.") - - def sync(self) -> None: - cute.arch.barrier(barrier_id=self.barrier_id) - - def get_barrier(self) -> int: - return self.barrier_id - - def max(self) -> int: - # Transaction barriers have a maximum arrive count of 4095 (2^12 - 1). - return 4095 - - -class TmaStoreFence(SyncObject): - """ - TmaStoreFence is used for a multi-stage epilogue buffer. - """ - - def __init__(self, num_stages: int = 0) -> None: - if num_stages <= 0: - raise ValueError("Mbarrier stage count must be greater than 0.") - - self.num_stages = num_stages - - def arrive(self) -> None: - cute.arch.cp_async_bulk_commit_group() - - def wait(self) -> None: - cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True) - - def arrive_and_wait(self) -> None: - self.arrive() - self.wait() - - def arrive_and_drop(self) -> None: - raise NotImplementedError("Error: Not supported.") - - # TmaStoreFence doesn't have mbarriers - def get_barrier(self) -> None: - assert ( - False - ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." - - def max(self) -> None: - raise NotImplementedError("Error: Not supported.") - - def tail(self) -> None: - cute.arch.cp_async_bulk_wait_group(0, read=True) - - -############################################################################## -# PipelineState class -############################################################################## - - -class PipelineUserType(enum.Enum): - Producer = enum.auto() - Consumer = enum.auto() - - -class PipelineState: - """ - Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. - """ - - def __init__(self, stages: int, count, index, phase): - self._stages = stages - self._count = count - self._index = index - self._phase = phase - - def clone(self) -> "PipelineState": - return PipelineState(self.stages, self._count, self.index, self.phase) - - @property - def index(self) -> Int32: - return self._index - - @property - def count(self) -> Int32: - return self._count - - @property - def stages(self) -> int: - return self._stages - - @property - def phase(self) -> Int32: - return self._phase - - def reset_count(self): - self._count = Int32(0) - - def advance(self): - self._index += 1 - self._count += 1 - - def then_body(index, phase): - new_index = Int32(0) - new_phase = phase ^ 1 - return new_index, new_phase - - def else_body(index, phase): - return index, phase - - self._index, self._phase = if_generate( - self._index == self.stages, - then_body, - else_body, - [self.index, self.phase], - [Int32, Int32], - ) - - def reverse(self): - self._index -= 1 - self._count -= 1 - - def then_body(index, phase): - new_index = Int32(self.stages - 1) - new_phase = phase ^ 1 - return new_index, new_phase - - def else_body(index, phase): - return index, phase - - self._index, self._phase = if_generate( - self._index == -1, - then_body, - else_body, - [self.index, self.phase], - [Int32, Int32], - ) - - def __get_mlir_types__(self): - return [self._count.type, self._index.type, self._phase.type] - - def __extract_mlir_values__(self): - count = self._count - index = self._index - phase = self._phase - return [count.ir_value(), index.ir_value(), phase.ir_value()] - - # This can be overridden by derived classes - def __new_from_mlir_values__(self, values): - return PipelineState( - self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2]) - ) - - -def make_pipeline_state(type: PipelineUserType, stages: int): - """ - Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. - """ - if type is PipelineUserType.Producer: - return PipelineState( - stages, - Int32(0), - Int32(0), - Int32(1), - ) - elif type is PipelineUserType.Consumer: - return PipelineState( - stages, - Int32(0), - Int32(0), - Int32(0), - ) - else: - assert ( - False - ), "Error: invalid PipelineUserType specified for make_pipeline_state." - - -############################################################################## -# Helper functions -############################################################################## - - -def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): - """ - Fences the mbarrier init and syncs the threadblock or cluster - """ - cute.arch.mbarrier_init_fence() - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - # If not using clusters, sync the threadblock - _sync(Agent.ThreadBlock) - else: - # If using clusters, sync the cluster - _sync(Agent.ThreadBlockCluster) - - -def _sync(group: Agent): - """ - Syncs all threads within an agent. - """ - if group is Agent.Thread: - raise NotImplementedError("Error: Not supported.") - elif group is Agent.ThreadBlock: - cute.arch.sync_threads() - elif group is Agent.ThreadBlockCluster: - cute.arch.cluster_arrive() - cute.arch.cluster_wait() - else: - assert ( - False - ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." - - -def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer: - """ - Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment - """ - return cute.make_ptr( - Int64, - val.ir_value(), - mem_space=_cute_ir.AddressSpace.smem, - assumed_align=8, - ) - - -# NamedBarrier free functions -def arrive(barrier_id: int, num_threads: int): - """ - The aligned flavor of arrive is used when all threads in the CTA will execute the - same instruction. See PTX documentation. - """ - cute.arch.barrier_arrive(barrier_id=barrier_id, number_of_threads=num_threads) - - -def arrive_unaligned(barrier_id: int, num_threads: int): - """ - The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. - """ - llvm.inline_asm( - None, - [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()], - "barrier.arrive $0, $1;", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -def wait(barrier_id: int, num_threads: int): - """ - NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. - If synchronizing two warps in a producer/consumer pairing, the arrive count would be - 32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer - or consumer are counted for mbarriers, while all threads participating in the sync - are counted for NamedBarriers. - """ - warnings.warn( - "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." - ) - arrive_and_wait() - - -def wait_unaligned(barrier_id: int, num_threads: int): - warnings.warn( - "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." - ) - llvm.inline_asm( - None, - [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()], - "barrier.sync $0, $1;", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -def arrive_and_wait(barrier_id: int, num_threads: int): - cute.arch.barrier(barrier_id=barrier_id, number_of_threads=num_threads) - - -def sync(barrier_id: int = 0): - cute.arch.barrier(barrier_id=barrier_id) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py deleted file mode 100644 index 2feed8cc0f1e702557f0c2b21b7582651a6405b8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py +++ /dev/null @@ -1,453 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Optional, Union -import warnings - -import cutlass.cute as cute -from cutlass.cutlass_dsl import Boolean, if_generate - -from cutlass.pipeline import ( - Agent, - CooperativeGroup, - PipelineOp, - PipelineState, - pipeline_init_wait, - PipelineAsync, -) - -############################################################################## -# Pipeline classes -############################################################################## - - -@dataclass(frozen=True) -class PipelineTmaUmma(PipelineAsync): - """ - PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops). - """ - - is_leader_cta: bool - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): - """ - Computes a mask for signaling arrivals to multicasting threadblocks. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) - - tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2 - ) - tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1 - ) - - block_in_cluster_coord_vmnk_peer = ( - cta_in_cluster_coord_vmnk[0] ^ 1, - *cta_in_cluster_coord_vmnk[1:], - ) - tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 - ) - tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 - ) - - return ( - tma_mcast_mask_a - | tma_mcast_mask_b - | tma_mcast_mask_a_peer - | tma_mcast_mask_b_peer - ) - - @staticmethod - def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): - """ - Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. - """ - bidx, bidy, _ = cute.arch.block_idx() - - mma_coord_vmnk = ( - bidx % cute.size(cta_layout_vmnk, mode=[0]), - bidx // cute.size(cta_layout_vmnk, mode=[0]), - bidy, - None, - ) - return mma_coord_vmnk[0] == 0 - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - tx_count: int, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: `CooperativeGroup` for the consumer agent - :type consumer_group: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.TmaLoad - consumer_type = PipelineOp.TCGen05Mma - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - # No mcast mask if not using clusters - producer_mask = None - # All threadblocks are leaders if not using clusters - is_leader_cta = True - else: - producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) - is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) - - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - - consumer_mask = producer_mask - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineTmaUmma( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - consumer_mask, - is_leader_cta, - cta_group, - ) - - def consumer_release(self, state: PipelineState): - """ - UMMA consumer release buffer empty, cta_group needs to be provided. - """ - self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - """ - TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(state.index, state.phase), - ) - if_generate( - self.is_leader_cta, - lambda: self.sync_object_full.arrive(state.index, self.producer_mask), - ) - - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a noop since TMA instruction itself updates the transaction count. - """ - pass - - -@dataclass(frozen=True) -class PipelineAsyncUmma(PipelineAsync): - """ - PipelineAsyncUmma is used for AsyncThread producers and UMMA consumers (e.g. Blackwell input fusion pipelines). - """ - - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def _compute_leading_cta_rank(cta_v_size): - """ - Computes the leading CTA rank. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - return cta_rank_in_cluster // cta_v_size * cta_v_size - - @staticmethod - def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): - """ - Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. - """ - bidx, bidy, _ = cute.arch.block_idx() - mma_coord_vmnk = ( - bidx % cute.size(cta_layout_vmnk, mode=[0]), - bidx // cute.size(cta_layout_vmnk, mode=[0]), - bidy, - None, - ) - return mma_coord_vmnk[0] == 0 - - @staticmethod - def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout): - """ - Computes a mask for signaling arrivals to multicasting threadblocks. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) - mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0 - ) - block_in_cluster_coord_vmnk_peer = ( - cta_in_cluster_coord_vmnk[0] ^ 1, - *cta_in_cluster_coord_vmnk[1:], - ) - mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0 - ) - return mask_self | mask_peer - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineAsyncUmma. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: `CooperativeGroup` for the consumer agent - :type consumer_group: CooperativeGroup - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.AsyncThread - consumer_type = PipelineOp.TCGen05Mma - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), - num_stages, - producer, - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - cta_v_size = ( - cute.size(cta_layout_vmnk, mode=[0]) if cta_layout_vmnk is not None else 1 - ) - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: - # No mcast mask if we're not using 2CTA tcgen05 MMA - producer_mask = None - consumer_mask = None - else: - # If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA - # We need to get the target cta_rank - producer_mask = PipelineAsyncUmma._compute_leading_cta_rank(cta_v_size) - # consumer needs to get the mask to signal - consumer_mask = PipelineAsyncUmma._compute_peer_cta_mask(cta_layout_vmnk) - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineAsyncUmma( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - consumer_mask, - cta_group, - ) - - def consumer_release(self, state: PipelineState): - """ - UMMA consumer release buffer empty, cta_group needs to be provided. - """ - self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) - - -@dataclass(frozen=True) -class PipelineUmmaAsync(PipelineAsync): - """ - PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines). - """ - - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout): - """ - Computes a mask to signal completion of tmem buffers for 2CTA kernels. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) - return cute.make_layout_image_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0 - ) - - @staticmethod - def _compute_peer_cta_rank(): - """ - Computes a mask to signal release of tmem buffers for 2CTA kernels. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - return cta_rank_in_cluster // 2 * 2 - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: `CooperativeGroup` for the consumer agent - :type consumer_group: CooperativeGroup - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.TCGen05Mma - consumer_type = PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - # Set mask to None if not using clusters (i.e. 1CTA kernels) - producer_mask = None - else: - producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk) - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: - # Set mask to None if not using 2CTA intructions - consumer_mask = None - else: - consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank() - - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineUmmaAsync( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - consumer_mask, - cta_group, - ) - - def producer_commit(self, state: PipelineState): - """ - UMMA producer commit buffer full, cta_group needs to be provided. - """ - self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group) - - def producer_tail(self, state: PipelineState): - """ - Make sure the last used buffer empty signal is visible to producer. - Producer tail is usually executed by producer before exit, to avoid dangling - mbarrier arrive signals after kernel exit. - - :param state: The pipeline state that points to next useful buffer - :type state: PipelineState - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 - - def then_body(): - # Assume state contains that next useful buffer - # So we only need to advance to num_stages - 1 times to last used buffer - for i in range(self.num_stages - 1): - state.advance() - self.producer_acquire(state) - - if_generate(is_leader_cta, then_body) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py deleted file mode 100644 index 5fc19960c9b1ccca84dcc18bca002e2fa2a303ca..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py +++ /dev/null @@ -1,985 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from typing import Type, Tuple -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Optional, Union -import warnings - -import cutlass -import cutlass.cute as cute -from cutlass.cutlass_dsl import Boolean, Int32, if_generate - -from cutlass.pipeline import ( - Agent, - CooperativeGroup, - PipelineOp, - SyncObject, - MbarrierArray, - TmaStoreFence, - PipelineUserType, - PipelineState, - make_pipeline_state, - pipeline_init_wait, -) - -############################################################################## -# Pipeline classes -############################################################################## - - -@dataclass(frozen=True) -class PipelineAsync: - """PipelineAsync is a generic pipeline class where both the producer and consumer are - AsyncThreads. It also serves as a base class for specialized pipeline classes. - - This class implements a producer-consumer pipeline pattern where both sides operate - asynchronously. The pipeline maintains synchronization state using barrier objects - to coordinate between producer and consumer threads. - - The pipeline state transitions of one pipeline entry(mbarrier) can be represented as: - - .. table:: Pipeline State Transitions - :widths: auto - - +-----------+-----------+-----------+-----------+-----------+-----------+ - | Barrier | State | p.acquire | p.commit | c.wait | c.release | - +===========+===========+===========+===========+===========+===========+ - | empty_bar | empty | | n/a | n/a | - | - +-----------+-----------+-----------+-----------+-----------+-----------+ - | empty_bar | wait | | n/a | n/a | -> empty | - +-----------+-----------+-----------+-----------+-----------+-----------+ - | full_bar | wait | n/a | -> full | | n/a | - +-----------+-----------+-----------+-----------+-----------+-----------+ - | full_bar | full | n/a | - | | n/a | - +-----------+-----------+-----------+-----------+-----------+-----------+ - - Where: - - - p: producer - - c: consumer - - : This action is blocked until transition to a state allow it to proceed by other side - - e.g. ``p.acquire()`` is blocked until ``empty_bar`` transition to ``empty`` state by ``c.release()`` - - .. code-block:: text - - Array of mbarriers as circular buffer: - - Advance Direction - <------------------- - - Producer Consumer - | ^ - V | - +-----------------+ - --|X|X|W|D|D|D|D|R|X|<-. - / +-----------------+ \\ - | | - `------------------------' - - Where: - - - X: Empty buffer (initial state) - - W: Producer writing (producer is waiting for buffer to be empty) - - D: Data ready (producer has written data to buffer) - - R: Consumer reading (consumer is consuming data from buffer) - - **Example:** - - .. code-block:: python - - # Create pipeline with 5 stages - pipeline = PipelineAsync.create( - num_stages=5, # number of pipeline stages - producer_group=producer_warp, - consumer_group=consumer_warp - barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory - ) - - producer, consumer = pipeline.make_participants() - # Producer side - for i in range(num_iterations): - handle = producer.acquire_and_advance() # Wait for buffer to be empty & Move index to next stage - # Write data to pipeline buffer - handle.commit() # Signal buffer is full - - # Consumer side - for i in range(num_iterations): - handle = consumer.wait_and_advance() # Wait for buffer to be full & Move index to next stage - # Read data from pipeline buffer - handle.release() # Signal buffer is empty - """ - - sync_object_full: SyncObject - sync_object_empty: SyncObject - num_stages: int - producer_mask: Optional[Int32] - consumer_mask: Optional[Int32] - - @staticmethod - def _make_sync_object( - barrier_storage: cute.Pointer, - num_stages: int, - agent: tuple[PipelineOp, CooperativeGroup], - tx_count: int = 0, - ) -> SyncObject: - """ - Returns a SyncObject corresponding to an agent's PipelineOp. - """ - if agent[0] in [ - PipelineOp.AsyncThread, - PipelineOp.TmaLoad, - PipelineOp.TCGen05Mma, - PipelineOp.Composite, - PipelineOp.AsyncLoad, - ]: - return MbarrierArray( - barrier_storage=barrier_storage, - num_stages=num_stages, - agent=agent, - tx_count=tx_count, - ) - elif agent[0] is PipelineOp.TmaStore: - # Path taken for AsyncTmaStore - return TmaStoreFence(num_stages=num_stages) - else: - assert False, "Error: Invalid PipelineOp specified." - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - barrier_storage: cute.Pointer = None, - producer_mask: Int32 = None, - consumer_mask: Int32 = None, - ): - """Creates and initializes a new PipelineAsync instance. - - This helper function computes necessary attributes and returns an instance of PipelineAsync - with the specified configuration for producer and consumer synchronization. - - :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: int - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: `CooperativeGroup` for the consumer agent - :type consumer_group: CooperativeGroup - :param producer_mask: Mask for signaling arrives for the producer agent, defaults to ``None`` - :type producer_mask: Int32, optional - :param consumer_mask: Mask for signaling arrives for the consumer agent, defaults to ``None`` - :type consumer_mask: Int32, optional - :return: A new PipelineAsync instance - :rtype: PipelineAsync - :raises ValueError: If barrier_storage is not a cute.Pointer instance - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.AsyncThread - consumer_type = PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - pipeline_init_wait() - - return PipelineAsync( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - consumer_mask, - ) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(state.index, state.phase), - ) - - def producer_try_acquire(self, state: PipelineState): - return self.sync_object_empty.try_wait(state.index, state.phase) - - def producer_commit(self, state: PipelineState): - self.sync_object_full.arrive(state.index, self.producer_mask) - - def consumer_wait( - self, state: PipelineState, try_wait_token: Optional[Boolean] = None - ): - if_generate( - try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait(state.index, state.phase), - ) - - def consumer_try_wait(self, state: PipelineState): - return self.sync_object_full.try_wait(state.index, state.phase) - - def consumer_release(self, state: PipelineState): - self.sync_object_empty.arrive(state.index, self.consumer_mask) - - def producer_get_barrier(self, state: PipelineState) -> cute.Pointer: - return self.sync_object_full.get_barrier(state.index) - - def producer_tail(self, state: PipelineState): - """ - Make sure the last used buffer empty signal is visible to producer. - Producer tail is usually executed by producer before exit, to avoid dangling - mbarrier arrive signals after kernel exit. - - :param state: The pipeline state that points to next useful buffer - :type state: PipelineState - """ - # Assume state contains that next useful buffer - # So we only need to advance to num_stages - 1 times to last used buffer - for i in range(self.num_stages - 1): - state.advance() - self.producer_acquire(state) - - # Util methods to manage produer and consumer - def make_producer(self): - state = make_pipeline_state(PipelineUserType.Producer, self.num_stages) - return PipelineProducer(self, state, self.sync_object_full.cg) - - def make_consumer(self): - state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages) - return PipelineConsumer(self, state, self.sync_object_empty.cg) - - def make_participants(self): - return self.make_producer(), self.make_consumer() - - - -@dataclass(frozen=True) -class PipelineCpAsync(PipelineAsync): - """ - PipelineCpAsync is used for CpAsync producers and AsyncThread consumers (e.g. Hopper non-TMA mainloops). - """ - - @staticmethod - def create( - barrier_storage: cute.Pointer, - num_stages: Int32, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - producer_mask: Int32 = None, - consumer_mask: Int32 = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent - :type consumer_group: CooperativeGroup - :param producer_mask: Mask for signaling arrives for the producer agent - :type producer_mask: Int32 | None - :param consumer_mask: Mask for signaling arrives for the consumer agent - :type consumer_mask: Int32 | None - """ - producer_type = PipelineOp.AsyncLoad - consumer_type = PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_array_full = PipelineCpAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer - ) - sync_object_array_empty = PipelineCpAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - pipeline_init_wait() - - return PipelineCpAsync( - sync_object_array_full, - sync_object_array_empty, - num_stages, - producer_mask, - consumer_mask, - ) - - -@dataclass(frozen=True) -class PipelineTmaAsync(PipelineAsync): - """ - PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops). - """ - - is_signalling_thread: Boolean - - @staticmethod - @cute.jit - def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32): - """ - Initialize the empty barrier arrive signal - This function returns the destination cta rank and a boolean indicating if the signalling thread is the same as the current thread - """ - # Logic to optimally schedule Empty Arrives - cluster_shape_vmnk = cta_layout_vmnk.shape - - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - - tidx = tidx % 32 - is_signalling_thread = tidx < cute.size(cluster_shape_vmnk) - dst_rank = tidx % cute.size(cluster_shape_vmnk) - - dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank) - cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster) - - is_same_row = ( - dst_cta_coord[0] == cur_cta_coord[0] - and dst_cta_coord[1] == cur_cta_coord[1] - and dst_cta_coord[3] == cur_cta_coord[3] - ) - is_same_col = ( - dst_cta_coord[0] == cur_cta_coord[0] - and dst_cta_coord[2] == cur_cta_coord[2] - and dst_cta_coord[3] == cur_cta_coord[3] - ) - - is_same_row_or_col = is_same_row or is_same_col - is_signalling_thread_final = is_signalling_thread and is_same_row_or_col - - return dst_rank, is_signalling_thread_final - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - tx_count: int, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - tidx: Optional[Int32] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: `CooperativeGroup` for the consumer agent - :type consumer_group: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - :param tidx: thread index to consumer async threads - :type tidx: Int32 | None - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.TmaLoad - consumer_type = PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - if tidx is None: - tidx, _, _ = cute.arch.thread_idx() - if cta_layout_vmnk is None: - cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) - ( - dst_rank, - is_signalling_thread, - ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - dst_rank = None - else: - dst_rank = dst_rank - - producer_mask = None - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineTmaAsync( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - dst_rank, - is_signalling_thread, - ) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - """ - TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(state.index, state.phase), - ) - self.sync_object_full.arrive(state.index, self.producer_mask) - - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a noop since TMA instruction itself updates the transaction count. - """ - pass - - def consumer_release(self, state: PipelineState): - """ - TMA consumer release conditionally signals the empty buffer to the producer. - """ - if_generate( - self.is_signalling_thread, - lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), - ) - - -@dataclass(frozen=True) -class PipelineTmaMultiConsumersAsync(PipelineAsync): - """ - PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers. - """ - - is_leader_cta: bool - sync_object_empty_umma: SyncObject - sync_object_empty_async: SyncObject - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group_umma: CooperativeGroup, - consumer_group_async: CooperativeGroup, - tx_count: int, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group_umma: `CooperativeGroup` for the UMMA consumer agent - :type consumer_group_umma: CooperativeGroup - :param consumer_group_async: `CooperativeGroup` for the AsyncThread consumer agent - :type consumer_group_async: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.TmaLoad - consumer_type = PipelineOp.Composite - consumer_type_umma = PipelineOp.TCGen05Mma - consumer_type_async = PipelineOp.AsyncThread - - if consumer_group_umma.agent != consumer_group_async.agent: - raise ValueError( - "UMMA and AsyncThread consumer groups must be the same agent" - ) - - if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1: - raise ValueError( - f"PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, cta_layout_vmnk:{cta_layout_vmnk}" - ) - - consumer_group = CooperativeGroup( - consumer_group_umma.agent, - consumer_group_umma.size + consumer_group_async.size, - ) - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - sync_object_empty_umma = sync_object_empty.recast_to_new_op_type( - consumer_type_umma - ) - sync_object_empty_async = sync_object_empty.recast_to_new_op_type( - consumer_type_async - ) - - # No mcast mask if not using clusters - producer_mask = None - consumer_mask = None - # All threadblocks are leaders if not using clusters - is_leader_cta = True - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineTmaMultiConsumersAsync( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - consumer_mask, - is_leader_cta, - sync_object_empty_umma, - sync_object_empty_async, - cta_group, - ) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - """ - TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(state.index, state.phase), - ) - if_generate( - self.is_leader_cta, - lambda: self.sync_object_full.arrive(state.index, self.producer_mask), - ) - - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a noop since TMA instruction itself updates the transaction count. - """ - pass - - def consumer_release(self, state: PipelineState, op_type: PipelineOp): - if op_type == PipelineOp.TCGen05Mma: - self.sync_object_empty_umma.arrive( - state.index, self.consumer_mask, self.cta_group - ) - elif op_type == PipelineOp.AsyncThread: - self.sync_object_empty_async.arrive(state.index, self.consumer_mask) - else: - raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}") - - -@dataclass(frozen=True) -class PipelineTmaStore(PipelineAsync): - """ - PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers. - """ - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaStore. - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent - :type producer_group: CooperativeGroup - """ - - producer_type = PipelineOp.TmaStore - - producer = (producer_type, producer_group) - - sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer) - - return PipelineTmaStore(sync_object_full, None, num_stages, None, None) - - def producer_acquire(self): - self.sync_object_full.wait() - - def producer_commit(self): - self.sync_object_full.arrive() - - def consumer_wait(self): - assert False, "Error: PipelineTmaStore does not have a consumer agent." - - def consumer_release(self): - assert False, "Error: PipelineTmaStore does not have a consumer agent." - - def producer_tail(self): - self.sync_object_full.tail() - - -################################################################# -# Utilities to help user of pipeline to simplify the workflow -################################################################# - - -class ImmutableResourceHandle: - __origin: PipelineAsync - __immutable_state: PipelineState - - def __init__(self, origin: PipelineAsync, immutable_state: PipelineState): - self.__origin = origin - self.__immutable_state = immutable_state - - @property - def index(self): - """Get the index of the current pipeline stage.""" - return self.__immutable_state.index - - @property - def count(self): - """Get the count of how many handles this producer has committed. - This is useful for tracking the number of blocks that have been loaded from gmem. - """ - return self.__immutable_state.count - - def get_origin(self): - """Get the original pipeline this resource handle belongs to.""" - return self.__origin - - def __extract_mlir_values__(self): - """Extract MLIR values from the current state. - - :return: List of MLIR values representing the current state - :rtype: list - """ - # TODO: need to handle pipeline as well - return self.__immutable_state.__extract_mlir_values__() - - def __new_from_mlir_values__(self, values): - """Create a new Producer instance from MLIR values. - - :param values: MLIR values to initialize the state - :type values: Any - :return: New Producer instance with state initialized from values - :rtype: Producer - """ - return self.__class__( - self.__origin, self.__immutable_state.__new_from_mlir_values__(values) - ) - -class PipelineProducer: - """A class representing a producer in an asynchronous pipeline. - - The Producer class manages the producer side of an asynchronous pipeline, handling - synchronization and state management for producing data. It provides methods for - acquiring, committing, and advancing through pipeline stages. - - :ivar __pipeline: The asynchronous pipeline this producer belongs to - :type __pipeline: PipelineAsync - :ivar __state: The current state of the producer in the pipeline - :type __state: PipelineState - :ivar __group: The cooperative group this producer operates in - :type __group: CooperativeGroup - - **Examples:** - - .. code-block:: python - - pipeline = PipelineAsync.create(...) - producer = pipeline.create_producer(producer_group, stages) - for i in range(iterations): - handle = producer.acquire_and_advance() # Wait for buffer to be empty - # Produce data - producer.commit(handle) # Signal data is ready - # An alternative way to do this is: - # handle.commit() # Signal data is ready - """ - - __pipeline: PipelineAsync - __state: PipelineState - __group: CooperativeGroup - - class ImmutableResourceHandle(ImmutableResourceHandle): - @property - def barrier(self): - """Get the barrier pointer for the current pipeline stage. - - :return: Pointer to the barrier for the current stage - :rtype: cute.Pointer - """ - return self.get_origin().producer_get_barrier( - self._ImmutableResourceHandle__immutable_state - ) - - def commit(self): - """Signal that data production is complete for the current stage. - This allows consumers to start processing the data. - """ - self.get_origin().producer_commit( - self._ImmutableResourceHandle__immutable_state - ) - - def __init__(self, pipeline, state, group: CooperativeGroup): - """Initialize a new Producer instance. - - :param pipeline: The pipeline this producer belongs to - :type pipeline: PipelineAsync - :param state: Initial pipeline state - :type state: PipelineState - :param group: The cooperative group for synchronization - :type group: CooperativeGroup - """ - self.__pipeline = pipeline - self.__state = state - self.__group = group - - def acquire( - self, - try_acquire_token: Optional[Boolean] = None, - ) -> ImmutableResourceHandle: - """Wait for the current buffer to be empty before producing data. - This is a blocking operation. - - :param try_acquire_token: Optional token to try to acquire the buffer - :type try_acquire_token: Optional[Boolean] - :return: A handle to the producer for committing the data - :rtype: ImmutableResourceHandle - """ - self.__pipeline.producer_acquire(self.__state, try_acquire_token) - handle = PipelineProducer.ImmutableResourceHandle( - self.__pipeline, self.__state.clone() - ) - return handle - - def advance(self): - """Move to the next pipeline stage.""" - self.__state.advance() - - def acquire_and_advance( - self, try_acquire_token: Optional[Boolean] = None - ) -> ImmutableResourceHandle: - """Wait for the current buffer to be empty before producing data. - Then advance to the next stage. - This is a blocking operation. - - :param try_acquire_token: Optional token to try to acquire the buffer - :type try_acquire_token: Optional[Boolean] - :return: A handle to the producer for committing the data - :rtype: ImmutableResourceHandle - """ - handle = self.acquire(try_acquire_token) - self.advance() - return handle - - def try_acquire(self) -> Boolean: - """Try to acquire the current buffer without blocking. - - :return: True if acquisition was successful, False otherwise - :rtype: Boolean - """ - return self.__pipeline.producer_try_acquire(self.__state) - - def commit(self, handle: Optional[ImmutableResourceHandle] = None): - """Signal that data production is complete for the current stage. - This allows consumers to start processing the data. - """ - if handle is not None: - assert ( - handle.get_origin() is self - ), "ResourceHandle does not belong to this PipelineProducer instance" - handle.commit() - else: - self.__pipeline.producer_commit(self.__state) - - def tail(self): - """Ensure all used buffers are properly synchronized before producer exit. - This should be called before the producer finishes to avoid dangling signals. - """ - self.__pipeline.producer_tail(self.__state) - - def __extract_mlir_values__(self): - """Extract MLIR values from the current state. - - :return: List of MLIR values representing the current state - :rtype: list - """ - # TODO: need to handle pipeline as well - return self.__state.__extract_mlir_values__() - - def __new_from_mlir_values__(self, values): - """Create a new Producer instance from MLIR values. - - :param values: MLIR values to initialize the state - :type values: Any - :return: New Producer instance with state initialized from values - :rtype: Producer - """ - return PipelineProducer( - self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group - ) - -class PipelineConsumer: - """A class representing a consumer in an asynchronous pipeline. - - The Consumer class manages the consumer side of an asynchronous pipeline, handling - synchronization and state management for consuming data. It provides methods for - waiting, releasing, and advancing through pipeline stages. - - :ivar __pipeline: The asynchronous pipeline this consumer belongs to - :type __pipeline: PipelineAsync - :ivar __state: The current state of the consumer in the pipeline - :type __state: PipelineState - :ivar __group: The cooperative group this consumer operates in - :type __group: CooperativeGroup - - **Examples:** - .. code-block:: python - - pipeline = PipelineAsync.create(...) - consumer = pipeline.create_consumer(consumer_group, stages) - for i in range(iterations): - handle = consumer.wait_and_advance() # Wait for data to be ready - # Consume data - consumer.release(handle) # Signal buffer is empty - # An alternative way to do this is: - # handle.release() # Signal buffer is empty - """ - - __pipeline: PipelineAsync - __state: PipelineState - __group: CooperativeGroup - - class ImmutableResourceHandle(ImmutableResourceHandle): - def release(self): - """Signal that data production is complete for the current stage. - This allows consumers to start processing the data. - """ - self.get_origin().consumer_release( - self._ImmutableResourceHandle__immutable_state - ) - - def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): - """Initialize a new Consumer instance. - - :param pipeline: The pipeline this consumer belongs to - :type pipeline: PipelineAsync - :param state: Initial pipeline state - :type state: PipelineState - :param group: The cooperative group for synchronization - :type group: CooperativeGroup - """ - self.__pipeline = pipeline - self.__group = group - self.__state = state - - def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle: - """Wait for data to be ready in the current buffer. - This is a blocking operation. - - :param try_wait_token: Optional token to try to wait for the buffer - :type try_wait_token: Optional[Boolean] - :return: A handle to the consumer for releasing the data - :rtype: PipelineConsumerHandle - """ - self.__pipeline.consumer_wait(self.__state, try_wait_token) - handle = PipelineConsumer.ImmutableResourceHandle( - self.__pipeline, self.__state.clone() - ) - return handle - - def advance(self): - """Move to the next pipeline stage.""" - self.__state.advance() - - def wait_and_advance( - self, try_wait_token: Optional[Boolean] = None - ) -> ImmutableResourceHandle: - """Wait for data to be ready in the current buffer. - Then advance to the next stage. - This is a blocking operation. - - :param try_wait_token: Optional token to try to wait for the buffer - :type try_wait_token: Optional[Boolean] - :return: A handle to the consumer for releasing the data - :rtype: PipelineConsumerHandle - """ - handle = self.wait(try_wait_token) - self.advance() - return handle - - def try_wait(self) -> Boolean: - """Try to check if data is ready without blocking. - - :return: True if data is ready, False otherwise - :rtype: Boolean - """ - return self.__pipeline.consumer_try_wait(self.__state) - - def release(self, handle: Optional[ImmutableResourceHandle] = None): - """Signal that data consumption is complete for the current stage. - This allows producers to start producing new data. - """ - if handle is not None: - assert ( - handle.get_origin() is self - ), "ResourceHandle does not belong to this PipelineConsumer instance" - handle.release() - else: - self.__pipeline.consumer_release(self.__state) - - def __extract_mlir_values__(self): - """Extract MLIR values from the current state. - - :return: List of MLIR values representing the current state - :rtype: list - """ - return self.__state.__extract_mlir_values__() - - def __new_from_mlir_values__(self, values): - """Create a new Consumer instance from MLIR values. - - :param values: MLIR values to initialize the state - :type values: Any - :return: New Consumer instance with state initialized from values - :rtype: Consumer - """ - # TODO: need to call pipeline.__new_from_mlir_values__ recursively - return PipelineConsumer( - self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py deleted file mode 100644 index e5ee5777cad35487f30b8705ff19747405d11194..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py +++ /dev/null @@ -1,311 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import ctypes -from math import prod -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Type, Union - -from cutlass.cute.typing import ( - Numeric, - Boolean, - Float, - Integer, - TFloat32, - Float8E4M3B11FNUZ, - Float8E4M3FN, - Float8E5M2, - Float8E8M0FNU, - Float4E2M1FN, - Tensor, -) -from cutlass.cute.runtime import from_dlpack -import cutlass.cute as cute -import torch -import cuda.bindings.driver as cuda - - -def dtype(ty: Type[Numeric]): - """ - Return the corresponding torch.dtype per the given DSL type - """ - torch_dtype = getattr(torch, ty.__name__.lower(), None) - - torch_type_map = { - Boolean: torch.bool, - # TFloat32 is just alias of float32 - TFloat32: torch.float32, - Float8E5M2: torch.float8_e5m2, - Float8E4M3FN: torch.float8_e4m3fn, - Float8E4M3B11FNUZ: torch.float8_e4m3fnuz, - } - if torch_dtype is None: - torch_dtype = torch_type_map.get(ty) - - if torch_dtype is None: - raise TypeError(f"{ty} is not supported by torch") - return torch_dtype - - -def as_tensor(pointer, shape, torch_type): - """Convert a pointer to a torch tensor""" - if torch_type.itemsize == 1: - cytype = ctypes.c_uint8 - elif torch_type.itemsize == 2: - cytype = ctypes.c_uint16 - elif torch_type.itemsize == 4: - cytype = ctypes.c_uint32 - elif torch_type.itemsize == 8: - cytype = ctypes.c_uint64 - else: - raise ValueError(f"Unsupported torch dtype: {torch_type}") - cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype)) - arr = (cpointer._type_ * prod(shape)).from_address( - ctypes.addressof(cpointer.contents) - ) - return torch.frombuffer(arr, dtype=torch_type).view(*shape) - - -@dataclass -class ScalarInitConfig: - """Configuration for scalar initialization""" - - value: float = 0.0 - - -@dataclass -class RandomInitConfig: - """Configuration for random initialization""" - - min_val: int = -2 - max_val: int = 2 - - -@dataclass -class GaussianInitConfig: - """Configuration for Gaussian initialization""" - - mean: float = 0.0 - std: float = 1.0 - scale: float = 1.0 - - -class TensorInitType(Enum): - """Enumeration of tensor initialization types""" - - SKIP = "skip" - SCALAR = "scalar" - RANDOM = "random" - GAUSSIAN = "gaussian" - - -def create_and_permute_torch_tensor( - shape, - dtype: "torch.dtype", - permute_order=None, - init_type: TensorInitType = TensorInitType.RANDOM, - init_config: Optional[ - Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] - ] = None, - device: Optional[torch.device] = None, -) -> "torch.Tensor": - """ - Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config - """ - init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32 - init_torch_tensor = torch.empty(*shape, dtype=init_dtype, device=device) - if init_type == TensorInitType.SKIP: - assert init_config is None - f32_torch_tensor = init_torch_tensor - elif init_type == TensorInitType.SCALAR: - if init_config is None: - init_config = ScalarInitConfig() - else: - if not isinstance(init_config, ScalarInitConfig): - raise ValueError("init_config must be ScalarInitConfig()") - f32_torch_tensor = init_torch_tensor.fill_(init_config.value) - elif init_type == TensorInitType.RANDOM: - if init_config is None: - init_config = RandomInitConfig() - else: - if not isinstance(init_config, RandomInitConfig): - raise ValueError("init_config must be RandomInitConfig()") - f32_torch_tensor = init_torch_tensor.random_( - init_config.min_val, init_config.max_val - ).to(dtype=torch.float32) - elif init_type == TensorInitType.GAUSSIAN: - if init_config is None: - init_config = GaussianInitConfig() - else: - if not isinstance(init_config, GaussianInitConfig): - raise ValueError("init_config must be GaussianInitConfig()") - f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std) - f32_torch_tensor = f32_torch_tensor * init_config.scale - else: - raise ValueError(f"Invalid init type: {init_type}") - - if permute_order is not None: - f32_torch_tensor = f32_torch_tensor.permute(permute_order) - - dtype_torch_tensor = f32_torch_tensor.to(dtype=dtype) - - return dtype_torch_tensor - - -def convert_cute_tensor( - f32_torch_tensor: "torch.Tensor", - cute_tensor: Tensor, - dtype: Type[Numeric], - is_dynamic_layout: bool = True, -) -> Tensor: - """ - Change the value of the cute tensor to make its value converted from a fp32 torch tensor. - Used for fp8 types tensor creatation now. - """ - # if torch_tensor is on cpu, create a gpu copy - if f32_torch_tensor.device.type == "cpu": - f32_torch_tensor = f32_torch_tensor.cuda() - - # Fp8 type need explicit type conversion - if dtype in { - Float8E5M2, - Float8E4M3FN, - Float8E8M0FNU, - Float4E2M1FN, - }: - fp32_cute_tensor = from_dlpack(f32_torch_tensor) - if is_dynamic_layout: - fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic( - f32_torch_tensor.dim_order()[-1] - ) - # Copy and convert from f32 cute tensor to dtype cute tensor - cute.testing.convert(fp32_cute_tensor, cute_tensor) - return cute_tensor - - -def default_stream() -> cuda.CUstream: - """ - Get default CUstream from torch stream - """ - torch_stream = torch.cuda.default_stream() - stream = cuda.CUstream(torch_stream.cuda_stream) - return stream - - -def current_stream() -> cuda.CUstream: - """ - Get current CUstream from torch stream - """ - torch_stream = torch.cuda.current_stream() - stream = cuda.CUstream(torch_stream.cuda_stream) - return stream - - -def matrix( - l: int, - mode0: int, - mode1: int, - is_mode0_major: bool, - cutlass_dtype: Type[Numeric], - init_type: TensorInitType = TensorInitType.RANDOM, - init_config: Optional[ - Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] - ] = None, - device: Optional[torch.device] = None, -) -> torch.Tensor: - """ - Create a torch tensor for matrix - - :param l: length of the matrix - :param mode0: mode0 of the matrix - :param mode1: mode1 of the matrix - :param is_mode0_major: whether the matrix is mode0 major - :param cutlass_dtype: cutlass dtype of the matrix - :param init_type: type of initialization - :param init_config: configuration for initialization - :param device: target torch device - """ - - shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) - permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - - if cutlass_dtype.is_float and cutlass_dtype.width <= 8: - torch_dtype = torch.int8 - else: - torch_dtype = dtype(cutlass_dtype) - - if init_type == TensorInitType.RANDOM and init_config is None: - if torch_dtype.is_signed: - min_val = -2 - max_val = 2 - else: - min_val = 0 - max_val = 4 - init_config = RandomInitConfig(min_val=min_val, max_val=max_val) - - # Create dtype torch tensor - torch_tensor = create_and_permute_torch_tensor( - shape, - torch_dtype, - permute_order=permute_order, - init_type=init_type, - init_config=init_config, - device=device, - ) - - return torch_tensor - - -def cute_tensor_like( - data_ref: torch.Tensor, - cutlass_dtype: Type[Numeric], - is_dynamic_layout: bool, - assumed_align: Optional[int] = None, -) -> tuple[Tensor, torch.Tensor]: - """ - Create a cute tensor use a torch tensor as the data source - - :param data_ref: torch tensor as the data source - :param cutlass_dtype: cutlass dtype of the cute tensor - :param is_dynamic_layout: whether the cute tensor uses dynamic layout - :param assumed_align: assumed alignment of the cute tensor - """ - - # allocate device buffer for cute tensor - if cutlass_dtype.is_float and cutlass_dtype.width <= 8: - torch_dtype = torch.int8 - else: - torch_dtype = dtype(cutlass_dtype) - torch_tensor = torch.empty_like(data_ref, dtype=torch_dtype, device="cuda") - - # create cute tensor using the device buffer - cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align) - cute_tensor.element_type = cutlass_dtype - if is_dynamic_layout: - for i, stride in enumerate(torch_tensor.stride()): - if stride == 1: - leading_dim = i - break - cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) - - # initialize the cute tensor data - if cutlass_dtype.is_float and cutlass_dtype.width <= 8: - cute_tensor = convert_cute_tensor( - data_ref.to(dtype=torch.float32), - cute_tensor, - cutlass_dtype, - is_dynamic_layout, - ) - else: - torch_tensor.copy_(data_ref.to(dtype=torch_dtype)) - - return cute_tensor, torch_tensor diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py deleted file mode 100644 index aec0a186d7a8fc18d65637e97905c7cd5702310d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .static_persistent_tile_scheduler import ( - WorkTileInfo, - PersistentTileSchedulerParams, - StaticPersistentTileScheduler, -) - -from .hardware_info import ( - HardwareInfo, -) - -from .blackwell_helpers import ( - compute_epilogue_tile_shape, - get_smem_store_op, - get_tmem_load_op, - get_num_tmem_alloc_cols, - make_smem_layout_a, - make_smem_layout_b, - make_smem_layout_epi, - make_trivial_tiled_mma, - make_blockscaled_trivial_tiled_mma, -) - -from .hopper_helpers import ( - sm90_get_smem_store_op, -) - -from .blockscaled_layout import ( - BlockScaledBasicChunk, - tile_atom_to_shape_SF, - make_smem_layout_sfa, - make_smem_layout_sfb, - make_tmem_layout_sfa, - make_tmem_layout_sfb, -) - -from .grouped_gemm_tile_scheduler_helper import ( - GroupSearchResult, - GroupedGemmGroupSearchState, - GroupedGemmTileSchedulerHelper, - create_initial_search_state, -) - -from .tensormap_manager import ( - TensorMapUpdateMode, - TensorMapManager, -) - -from .smem_allocator import SmemAllocator - -from .layout import LayoutEnum - -from .smem_capacity import ( - get_smem_capacity_in_bytes, -) - -from .distributed_helpers import ( - spin_lock_wait, - spin_lock_multimem_arrive, - multimem_ld_reduce_8xf16, - multimem_ld_reduce_4xf32, - multimem_ld_reduce_8xbf16, - multimem_ld_reduce_16xe4m3, - multimem_ld_reduce_16xe5m2, - multimem_st_4xb32, - sm_wise_inter_gpu_multimem_barrier, -) - -__all__ = [ - "get_smem_capacity_in_bytes", - "SmemAllocator", - "LayoutEnum", - "WorkTileInfo", - "PersistentTileSchedulerParams", - "StaticPersistentTileScheduler", - "TensorMapUpdateMode", - "TensorMapManager", - "GroupSearchResult", - "GroupedGemmGroupSearchState", - "create_initial_search_state", - "GroupedGemmTileSchedulerHelper", - "HardwareInfo", -] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py deleted file mode 100644 index 1341756f3584f89b0c201631445beb91c34dc29e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from enum import Enum -from typing_extensions import deprecated -import warnings - - -@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") -class SmemCapacity(Enum): - SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024 - SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - - -warnings.warn( - "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", - DeprecationWarning, - stacklevel=2, -) -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value, - "sm86": SmemCapacity.SM86_SMEM_CAPACITY_BYTES.value, - "sm89": SmemCapacity.SM89_SMEM_CAPACITY_BYTES.value, -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py deleted file mode 100644 index 6fb6bf4dbfa3e73f058037e79b0999697d720502..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ /dev/null @@ -1,1135 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from enum import Enum -from math import log2, ceil -from typing import List, Type, Union, Tuple -from typing_extensions import deprecated -import warnings - -from cutlass.cutlass_dsl import ( - Float16, - BFloat16, - TFloat32, - Float32, - Uint8, - Int8, - Float8E4M3FN, - Float8E5M2, - Float4E2M1FN, - Numeric, - NumericMeta, - dsl_user_op, -) -import cutlass.cute as cute -from cutlass.cute.nvgpu.common import CopyUniversalOp -from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp -from cutlass.cute.nvgpu.tcgen05 import ( - MmaF16BF16Op, - MmaTF32Op, - MmaI8Op, - MmaFP8Op, - MmaMXF8Op, - MmaMXF4Op, - MmaMXF4NVF4Op, - OperandSource, - OperandMajorMode, - CtaGroup, - Ld16x64bOp, - Ld16x128bOp, - Ld16x256bOp, - Ld16x32bx2Op, - Ld32x32bOp, - Repetition, - Pack, - find_tmem_tensor_col_offset, - SmemLayoutAtomKind, - make_smem_layout_atom, - tile_to_mma_shape, - is_tmem_load, - get_tmem_copy_properties, -) -from cutlass.cute.nvgpu.cpasync import ( - CopyBulkTensorTileG2SMulticastOp, - CopyBulkTensorTileG2SOp, -) -from cutlass.utils.layout import LayoutEnum - - -@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") -class SmemCapacity(Enum): - SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 - SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - - -warnings.warn( - "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", - DeprecationWarning, - stacklevel=2, -) -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value, - "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, -} - - -@dsl_user_op -def compute_epilogue_tile_shape( - cta_tile_shape: cute.Shape, - use_2cta_instrs: bool, - layout_d: LayoutEnum, - elem_ty_d: Type[Numeric], - *, - layout_c: LayoutEnum = None, - elem_ty_c: Union[Type[Numeric], None] = None, - loc=None, - ip=None, -) -> cute.Tile: - """Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. - - :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile, where - cta_tile_shape[0] corresponds to the height (M) and cta_tile_shape[1] - corresponds to the width (N) of the tile. - :type cta_tile_shape: cute.Shape - :param use_2cta_instrs: A flag indicating whether the configuration is for a 2SM setup. - :type use_2cta_instrs: bool - :param layout_d: The layout enum of the output tensor D. - :type layout_d: LayoutEnum - :param elem_ty_d: The element type of output tensor D. - :type elem_ty_d: Type[Numeric] - :param layout_c: The layout enum of the input tensor C. Defaults to None. - :type layout_c: LayoutEnum, optional - :param elem_ty_c: The element type for input tensor C. Defaults to None. - :type elem_ty_c: Union[Type[Numeric], None], optional - - :return: Returns epilog tiler, which is used in subsequent epilog partitions. - :rtype: cute.Tile - - :raises ValueError: If the computed tile cute.size does not meet minimum requirements based on CTA dimensions. - """ - - def validate_type(ty, ty_name): - if not isinstance(ty, NumericMeta): - raise TypeError(f"{ty_name} must be Numeric, but got {ty}") - - validate_type(elem_ty_d, "elem_ty_d") - if elem_ty_c is not None: - validate_type(elem_ty_c, "elem_ty_c") - - cta_m, cta_n = cta_tile_shape[:2] - (warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1) - disable_source = elem_ty_c == None - max_bits = ( - elem_ty_d.width if disable_source else max(elem_ty_c.width, elem_ty_d.width) - ) - - dp_full = 32 - tile_m = min(cta_m, dp_full * warp_m) - n_perf = 0 - if disable_source: - if max_bits == 4: - compute_elts = 8192 - else: - compute_elts = 4096 - n_perf = compute_elts // tile_m - else: - if max_bits == 32: - n_perf = 16 if (cta_m > 64 and cta_n <= 128) else 32 - elif max_bits == 16: - n_perf = 32 if cta_n <= 128 else 64 - else: - n_perf = 64 - - d_is_m_major = layout_d.is_m_major_c() - c_is_m_major = True if layout_c is None else layout_c.is_m_major_c() - - n_min_d = ( - 8 * warp_n - if d_is_m_major - else (128 * warp_n if elem_ty_d.width == 6 else 128 // elem_ty_d.width * warp_n) - ) - n_min_c = ( - 8 * warp_n - if (c_is_m_major or disable_source) - else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n) - ) - tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d)) - - if cta_n < n_min_c or cta_n < n_min_d: - raise ValueError(f"CTA tile too small: {cta_tile_shape=}") - - # stride by tmem warp layout and return a by-mode tiler - tile_m_layout = cute.make_layout(tile_m, loc=loc, ip=ip) - tile_n_layout = cute.make_layout( - (tile_n // warp_n, warp_n), stride=(1, cta_n // warp_n), loc=loc, ip=ip - ) - return (tile_m_layout, cute.coalesce(tile_n_layout, loc=loc, ip=ip)) - - -@dsl_user_op -def get_smem_store_op( - layout_d: LayoutEnum, - elem_ty_d: Type[Numeric], - elem_ty_acc: Type[Numeric], - tiled_tmem_load: cute.TiledCopy, - *, - loc=None, - ip=None, -) -> cute.CopyAtom: - """Selects the largest vectorized smem store atom available subject to - constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership. - - :param layout_d: The layout enum of the output tensor D. - :type layout_d: LayoutEnum - :param elem_ty_d: The element type for output tensor D. - :type elem_ty_d: Type[Numeric] - :param elem_ty_acc: The element type for accumulator. - :type elem_ty_acc: Type[Numeric] - :param tiled_tmem_load: An instance of TiledCopy that represents the tmem load operation. - :type tiled_tmem_load: cute.TiledCopy - - :return: Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. - :rtype: cute.CopyAtom - """ - - def validate_type(ty, ty_name): - if not isinstance(ty, NumericMeta): - raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") - - validate_type(elem_ty_d, "elem_ty_d") - validate_type(elem_ty_acc, "elem_ty_acc") - - is_m_major = layout_d.is_m_major_c() - is_n_major = layout_d.is_n_major_c() - - if not is_tmem_load(tiled_tmem_load): - return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) - - num_dp, num_bits, num_rep, pack = get_tmem_copy_properties(tiled_tmem_load) - - use_stmatrix_m8n8_4x = ( - all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 32, - is_n_major, - num_dp == 16, - num_bits == 128, - num_rep in (2, 4, 8, 16, 32, 64), - pack == Pack.NONE, - ] - ) - or all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 16, - num_dp == 16, - num_bits == 256, - num_rep in (2, 4, 8, 16, 32), - pack == Pack.NONE, - ] - ) - or all( - [ - elem_ty_acc.width == 16, - elem_ty_d.width == 16, - num_dp == 16, - num_bits == 128, - num_rep in (2, 4, 8, 16, 32, 64), - pack == Pack.PACK_16b_IN_32b, - ] - ) - ) - use_stmatrix_m16n8_4x = all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 8, - is_m_major, - num_dp == 16, - num_bits == 256, - num_rep in (4, 8, 16, 32), - pack == Pack.NONE, - ] - ) - use_stmatrix_m8n8_2x = ( - all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 32, - is_n_major, - num_dp == 16, - num_bits == 128, - num_rep == 1, - pack == Pack.NONE, - ] - ) - or all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 16, - num_dp == 16, - num_bits == 256, - num_rep == 1, - pack == Pack.NONE, - ] - ) - or all( - [ - elem_ty_acc.width == 16, - elem_ty_d.width == 16, - num_dp == 16, - num_bits == 128, - num_rep == 1, - pack == Pack.PACK_16b_IN_32b, - ] - ) - ) - use_stmatrix_m16n8_2x = all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 8, - is_m_major, - num_dp == 16, - num_bits == 256, - num_rep == 2, - pack == Pack.NONE, - ] - ) - use_stmatrix_m16n8_1x = all( - [ - elem_ty_acc.width == 32, - elem_ty_d.width == 8, - is_m_major, - num_dp == 16, - num_bits == 256, - num_rep == 1, - pack == Pack.NONE, - ] - ) - - if use_stmatrix_m8n8_4x: - op = StMatrix8x8x16bOp(is_m_major, 4) - return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) - elif use_stmatrix_m8n8_2x: - op = StMatrix8x8x16bOp(is_m_major, 2) - return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) - elif use_stmatrix_m16n8_4x: - op = StMatrix16x8x8bOp(4) - return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) - elif use_stmatrix_m16n8_2x: - op = StMatrix16x8x8bOp(2) - return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) - elif use_stmatrix_m16n8_1x: - op = StMatrix16x8x8bOp(1) - return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) - else: - op = CopyUniversalOp() - return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) - - -@dsl_user_op -def get_tmem_load_op( - cta_tile_shape: cute.Shape, - layout_d: LayoutEnum, - elem_ty_d: Type[Numeric], - elem_ty_acc: Type[Numeric], - epi_tile: cute.Tile, - use_2cta_instrs: bool, - *, - loc=None, - ip=None, -) -> cute.CopyAtom: - """Finds a performant TMEM_LOAD copy op for the selected epilogue - tile (epi_tile), element types, and tcgen05.mma instruction used. - - :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile. - :type cta_tile_shape: cute.Shape - :param layout_d: The layout enum of the output tensor D. - :type layout_d: LayoutEnum - :param elem_ty_d: The element type for output tensor D. - :type elem_ty_d: Type[Numeric] - :param elem_ty_acc: The element type for accumulation. - :type elem_ty_acc: Type[Numeric] - :param epi_tile: The epilogue tile configuration. - :type epi_tile: cute.Tile - :param use_2cta_instrs: A flag indicating whether the configuration is for 2 SMs. - :type use_2cta_instrs: bool - - :return: An instance of Sm100TmemLoad with the computed configuration. - :rtype: cute.CopyAtom - - :raises ValueError: If the function cannot handle the given combination of accumulation - and dimension types, or if it cannot determine the appropriate configuration based on - the input parameters. - """ - is_m_major = layout_d.is_m_major_c() - - acc_bits = elem_ty_acc.width - d_bits = elem_ty_d.width - - tmem_warp_shape_mn = ( - (2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1) - ) - epilog_tile_shape_mn = cute.product_each( - cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip - ) - epilog_warp_tile_shape_mn = cute.shape_div( - epilog_tile_shape_mn, tmem_warp_shape_mn, loc=loc, ip=ip - ) - - num_dp = cute.size(epilog_warp_tile_shape_mn[0], loc=loc, ip=ip) - if num_dp not in {16, 32}: - raise ValueError("Cta tile and 2sm config does not generate correct num dp.") - - num_col_bits = cute.size(epilog_warp_tile_shape_mn[1], loc=loc, ip=ip) * acc_bits - - tmem_dp = 0 - tmem_bit = 0 - tmem_rep = 0 - tmem_pack16b = False - if acc_bits == 32 and d_bits == 32: - if num_dp == 16: - if is_m_major: - tmem_dp = 16 - tmem_bit = 256 - else: - tmem_dp = 16 - tmem_bit = 128 - else: - tmem_dp = 32 - tmem_bit = 32 - elif acc_bits == 32 and d_bits == 16: - if num_dp == 16: - if is_m_major: - tmem_dp = 16 - tmem_bit = 256 - else: - tmem_dp = 16 - tmem_bit = 256 - else: - if is_m_major: - tmem_dp = 16 - tmem_bit = 256 - else: - tmem_dp = 32 - tmem_bit = 32 - elif acc_bits == 32 and d_bits == 8: - if num_dp == 16: - if is_m_major: - tmem_dp = 16 - tmem_bit = 256 - else: - tmem_dp = 16 - tmem_bit = 32 - else: - if is_m_major: - tmem_dp = 16 - tmem_bit = 256 - else: - tmem_dp = 32 - tmem_bit = 32 - elif acc_bits == 16 and d_bits == 16: - tmem_pack16b = True - if num_dp == 16: - if is_m_major: - tmem_dp = 16 - tmem_bit = 128 - else: - tmem_dp = 16 - tmem_bit = 128 - else: - if is_m_major: - tmem_dp = 16 - tmem_bit = 128 - else: - tmem_dp = 32 - tmem_bit = 32 - elif acc_bits == 32 and d_bits == 6: - if not num_dp == 32: - raise ValueError("Num dp must be 32.") - tmem_dp = 32 - tmem_bit = 32 - elif acc_bits == 32 and d_bits == 4: - if not num_dp == 32: - raise ValueError("Num dp must be 32.") - tmem_dp = 32 - tmem_bit = 32 - else: - raise ValueError( - f"Can not handle acc/d type combination: {elem_ty_acc=}, {elem_ty_d=}" - ) - - num_bit_div = tmem_bit - if tmem_dp == 16 and tmem_bit == 32: - num_bit_div = 64 - - if (num_col_bits % (num_bit_div * 128) == 0) and ( - (tmem_dp == 16 and tmem_bit == 64) - or (tmem_dp == 16 and tmem_bit == 32) - or (tmem_dp == 32 and tmem_bit == 32) - ): - tmem_rep = 128 - elif (num_col_bits % (num_bit_div * 64) == 0) and ( - (tmem_dp == 16 and tmem_bit == 128) - or (tmem_dp == 16 and tmem_bit == 64) - or (tmem_dp == 16 and tmem_bit == 32) - or (tmem_dp == 32 and tmem_bit == 32) - ): - tmem_rep = 64 - elif num_col_bits % (num_bit_div * 32) == 0: - tmem_rep = 32 - elif num_col_bits % (num_bit_div * 16) == 0: - tmem_rep = 16 - elif num_col_bits % (num_bit_div * 8) == 0: - tmem_rep = 8 - elif num_col_bits % (num_bit_div * 4) == 0: - tmem_rep = 4 - elif num_col_bits % (num_bit_div * 2) == 0: - tmem_rep = 2 - elif num_col_bits % (num_bit_div * 1) == 0: - tmem_rep = 1 - else: - raise ValueError("Can not pick tmem_rep based on cta tile shape and tmem atom.") - - if tmem_dp == 16 and tmem_bit == 64: - op = Ld16x64bOp( - Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE - ) - return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) - elif tmem_dp == 16 and tmem_bit == 128: - op = Ld16x128bOp( - Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE - ) - return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) - elif tmem_dp == 16 and tmem_bit == 256: - op = Ld16x256bOp( - Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE - ) - return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) - elif tmem_dp == 16 and tmem_bit == 32: - op = Ld16x32bx2Op( - Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE - ) - return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) - - elif tmem_dp == 32 and tmem_bit == 32: - op = Ld32x32bOp( - Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE - ) - return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) - else: - raise ValueError() - - -def get_num_tmem_alloc_cols( - tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], rounding=True -) -> int: - """Get the total number of TMEM allocation columns for the given TMEM tensors. - - :param tmem_tensors: The TMEM tensors to get the number of allocation columns for. - :type tmem_tensors: Union[cute.Tensor, List[cute.Tensor]] - :param rounding: Whether to round up the number of allocation columns to the nearest power of 2. - :type rounding: bool - - :return: The total number of TMEM allocation columns. - :rtype: int - - :raises ValueError: If the number of TMEM allocation columns exceeds the maximum capacity of 512 or is less than 32. - """ - # Turn tmem_tensors into a list - if isinstance(tmem_tensors, cute.Tensor): - tmem_tensors = [tmem_tensors] - - # For each tensor in tmem_tensors, find the tmem_tensor_col_offset - num_tmem_alloc_cols_per_tensor = [ - find_tmem_tensor_col_offset(t) for t in tmem_tensors - ] - - # Sum up the num_tmem_alloc_cols_per_tensor - num_tmem_alloc_cols = sum(num_tmem_alloc_cols_per_tensor) - - # Round up num_tmem_cols_total to the nearest power of 2 - if rounding: - num_tmem_alloc_cols = 1 << ceil(log2(num_tmem_alloc_cols)) - - # Validate the number of TMEM allocation columns - SM100_TMEM_CAPACITY_COLUMNS = 512 - SM100_TMEM_MIN_ALLOC_COLUMNS = 32 - if ( - num_tmem_alloc_cols > SM100_TMEM_CAPACITY_COLUMNS - or num_tmem_alloc_cols < SM100_TMEM_MIN_ALLOC_COLUMNS - ): - raise ValueError( - f"TMEM allocation columns {num_tmem_alloc_cols} exceeds the maximum capacity of {SM100_TMEM_CAPACITY_COLUMNS} or less than {SM100_TMEM_MIN_ALLOC_COLUMNS}" - ) - return num_tmem_alloc_cols - - -def get_smem_layout_atom_ab( - major_mode: OperandMajorMode, - element_type: Type[Numeric], - smem_shape_mn_k: Tuple[int, int], - *, - loc=None, - ip=None, -) -> SmemLayoutAtomKind: - """Simple heuristics to select the optimal SMEM layout atom based on the - majorness, the data type, and the major mode size. - - :param major_mode: The major mode for the SMEM tensor is K major. - :type major_mode: OperandMajorMode - :param element_type: The element type for the SMEM tensor. - :type element_type: Type[Numeric] - :param smem_shape_mn_k: The shape of the SMEM tensor. - :type smem_shape_mn_k: Tuple[int, int] - - :return: The SMEM layout atom kind - :rtype: SmemLayoutAtomKind - """ - is_k_major = major_mode == OperandMajorMode.K - major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0] - - assert major_mode_size % 8 == 0 - sw128_num_contiguous_bits = 1024 - sw64_num_contiguous_bits = 512 - sw32_num_contiguous_bits = 256 - inter_num_contiguous_bits = 128 - major_mode_size_bits = major_mode_size * element_type.width - assert major_mode_size_bits % inter_num_contiguous_bits == 0 - - if not is_k_major: - if (element_type.width == 32) and ( - major_mode_size_bits % sw128_num_contiguous_bits == 0 - ): - return SmemLayoutAtomKind.MN_SW128_32B - if major_mode_size_bits % sw128_num_contiguous_bits == 0: - return SmemLayoutAtomKind.MN_SW128 - if major_mode_size_bits % sw64_num_contiguous_bits == 0: - return SmemLayoutAtomKind.MN_SW64 - if major_mode_size_bits % sw32_num_contiguous_bits == 0: - return SmemLayoutAtomKind.MN_SW32 - return SmemLayoutAtomKind.MN_INTER - if major_mode_size_bits % sw128_num_contiguous_bits == 0: - return SmemLayoutAtomKind.K_SW128 - if major_mode_size_bits % sw64_num_contiguous_bits == 0: - return SmemLayoutAtomKind.K_SW64 - if major_mode_size_bits % sw32_num_contiguous_bits == 0: - return SmemLayoutAtomKind.K_SW32 - return SmemLayoutAtomKind.K_INTER - - -@dsl_user_op -def make_smem_layout_a( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: cute.Tile, - a_dtype: Type[Numeric], - num_stages: int, - *, - loc=None, - ip=None, -) -> Union[cute.Layout, cute.ComposedLayout]: - """This function helps with: - 1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler. - 2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size. - 3. cute.Tile the SMEM layout atom to the MMA tile shape. - 4. Stage the SMEM layout based on the number of stages. - - :param tiled_mma: The tiled MMA used to partition tensor A - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The MMA tile shape - :type mma_tiler_mnk: cute.cute.Tile - :param a_dtype: The element type for tensor A - :type a_dtype: Type[Numeric] - :param num_stages: The number of pipeline stages for tensor A - :type num_stages: int - - :return: SMEM layout for tensor A - :rtype: Union[cute.Layout, cute.ComposedLayout] - """ - - is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K - a_smem_shape = tiled_mma.partition_shape_A( - cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) - ) - a_smem_shape_mn_k = ( - cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], - cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], - ) - a_smem_layout_atom = make_smem_layout_atom( - get_smem_layout_atom_ab( - tiled_mma.op.a_major_mode, - a_dtype, - a_smem_shape_mn_k, - loc=loc, - ip=ip, - ), - a_dtype, - loc=loc, - ip=ip, - ) - a_smem_layout_staged = tile_to_mma_shape( - a_smem_layout_atom, - cute.append(a_smem_shape, num_stages, loc=loc, ip=ip), - order=((1, 0, 2) if not is_k_major else (0, 1, 2)), - loc=loc, - ip=ip, - ) - return a_smem_layout_staged - - -@dsl_user_op -def make_smem_layout_b( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: cute.Tile, - b_dtype: Type[Numeric], - num_stages: int, - *, - loc=None, - ip=None, -) -> Union[cute.Layout, cute.ComposedLayout]: - """This function helps: - 1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler. - 2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size. - 3. cute.Tile the SMEM layout atom to the MMA tile shape. - 4. Stage the SMEM layout based on the number of stages. - - :param tiled_mma: The tiled MMA which is used to partition the B tensor. - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The MMA tile shape. - :type mma_tiler_mnk: cute.cute.Tile - :param b_dtype: The element type for the B tensor. - :type b_dtype: Type[Numeric] - :param num_stages: The stage of the B tensor. - :type num_stages: int - - :return: SMEM layout for the B tensor. - :rtype: Union[cute.Layout, cute.ComposedLayout] - """ - - is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K - b_smem_shape = tiled_mma.partition_shape_B( - cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip) - ) - b_smem_shape_nk = ( - cute.size(b_smem_shape[0][0], loc=loc, ip=ip) * b_smem_shape[1], - cute.size(b_smem_shape[0][1], loc=loc, ip=ip) * b_smem_shape[2], - ) - b_smem_layout_atom = make_smem_layout_atom( - get_smem_layout_atom_ab( - tiled_mma.op.b_major_mode, - b_dtype, - b_smem_shape_nk, - loc=loc, - ip=ip, - ), - b_dtype, - loc=loc, - ip=ip, - ) - b_smem_layout_staged = tile_to_mma_shape( - b_smem_layout_atom, - cute.append(b_smem_shape, num_stages, loc=loc, ip=ip), - order=((1, 0, 2) if not is_k_major else (0, 1, 2)), - loc=loc, - ip=ip, - ) - - return b_smem_layout_staged - - -@dsl_user_op -def get_smem_layout_atom_epi( - layout: LayoutEnum, - element_type: Type[Numeric], - epi_tile: cute.Tile, - *, - loc=None, - ip=None, -) -> SmemLayoutAtomKind: - """Simple heuristics to select the optimal SMEM layout atom for epilog tensors. - - :param layout: The layout enum for the SMEM tensor. - :type layout: LayoutEnum - :param element_type: The element type for the SMEM tensor. - :type element_type: Type[Numeric] - :param epi_tile: The epilogue tile shape. - :type epi_tile: cute.Tile - - :return: The SMEM layout atom kind - :rtype: SmemLayoutAtomKind - """ - # Get the max contiguous tile usable by TMA - tma_shape = tuple( - ( - # assumes get<0>(epi_tile) is coalesced and unit stride - cute.coalesce(cute.right_inverse(x, loc=loc, ip=ip), loc=loc, ip=ip).shape - if isinstance(x, cute.Layout) - else x - ) - for x in epi_tile - ) - - if layout.is_m_major_c(): - # ColMajor C/D (M-major) - return get_smem_layout_atom_ab( - OperandMajorMode.MN, element_type, tma_shape, loc=loc, ip=ip - ) - else: - # RowMajor C/D (N-major) - return get_smem_layout_atom_ab( - OperandMajorMode.K, element_type, tma_shape, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_smem_layout_epi( - epi_dtype: Type[Numeric], - epi_layout: LayoutEnum, - epi_tile: cute.Tile, - epi_stage: int, - *, - loc=None, - ip=None, -) -> Union[cute.Layout, cute.ComposedLayout]: - """This function helps: - 1. Select the heuristic SMEM layout atom based on the epilog tile shape, - the epilog tensor's majorness, and the element type. - 2. cute.Tile the SMEM layout atom to the epilog tile shape. - 3. Stage the SMEM layout based on the number of stages. - - :param epi_dtype: The element type for the epilog tensor. - :type epi_dtype: Type[Numeric] - :param epi_layout: The layout enum for the epilog tensor. - :type epi_layout: LayoutEnum - :param epi_tile: The epilogue tile shape. - :type epi_tile: cute.cute.Tile - :param epi_stage: The stage of the epilog tensor. - :type epi_stage: int - - :return: SMEM layout for epilog tensors (usually C & D which are processed in the epilog) - :rtype: Union[cute.Layout, cute.ComposedLayout] - """ - - epilog_shape = cute.product_each( - cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip - ) - - c_smem_layout_atom = make_smem_layout_atom( - get_smem_layout_atom_epi( - epi_layout, - epi_dtype, - epi_tile, - loc=loc, - ip=ip, - ), - epi_dtype, - loc=loc, - ip=ip, - ) - epi_smem_layout_staged = cute.tile_to_shape( - c_smem_layout_atom, - cute.append(epilog_shape, epi_stage, loc=loc, ip=ip), - order=((1, 0, 2) if not epi_layout.is_n_major_c() else (0, 1, 2)), - loc=loc, - ip=ip, - ) - - return epi_smem_layout_staged - - -@dsl_user_op -def make_trivial_tiled_mma( - ab_dtype: Type[Numeric], - a_leading_mode: OperandMajorMode, - b_leading_mode: OperandMajorMode, - acc_dtype: Type[Numeric], - cta_group: CtaGroup, - mma_tiler_mn: Tuple[int, int], - a_source: OperandSource = OperandSource.SMEM, - *, - loc=None, - ip=None, -) -> cute.TiledMma: - """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. - By default, the MMA atom is created with SMEM operand source for A. - - :param ab_dtype: Data type of operands A and B. - :type ab_dtype: type[Numeric] - :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). - :type a_leading_mode: tcgen05.OperandMajorMode - :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). - :type b_leading_mode: tcgen05.OperandMajorMode - :param acc_dtype: Data type of the accumulator. - :type acc_dtype: type[Numeric] - :param cta_group: The CTA group to use. - :type cta_group: tcgen05.CtaGroup - :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. - :type mma_tiler_mn: Tuple[int, int] - :param a_source: The source of operand A (SMEM by default or TMEM). - :type a_source: OperandSource - - :return: A tiled MMA atom. - :rtype: cute.TiledMma - - :raises TypeError: If the data type is not supported. - """ - - if ab_dtype in {Float16, BFloat16}: - mma_op = MmaF16BF16Op( - ab_dtype, - acc_dtype, - (*mma_tiler_mn, 16), - cta_group, - a_source, - a_leading_mode, - b_leading_mode, - ) - elif ab_dtype in {TFloat32, Float32}: - mma_op = MmaTF32Op( - (*mma_tiler_mn, 8), - cta_group, - a_source, - a_leading_mode, - b_leading_mode, - ) - elif ab_dtype in { - Uint8, - Int8, - }: - mma_op = MmaI8Op( - ab_dtype, - (*mma_tiler_mn, 32), - cta_group, - a_source, - a_leading_mode, - b_leading_mode, - ) - elif ab_dtype in {Float8E4M3FN, Float8E5M2}: - mma_op = MmaFP8Op( - ab_dtype, - acc_dtype, - (*mma_tiler_mn, 32), - cta_group, - a_source, - a_leading_mode, - b_leading_mode, - ) - else: - raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") - - return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) - - -@dsl_user_op -def make_blockscaled_trivial_tiled_mma( - ab_dtype: Type[Numeric], - a_leading_mode: OperandMajorMode, - b_leading_mode: OperandMajorMode, - sf_dtype: Type[Numeric], - sf_vec_size: int, - cta_group: CtaGroup, - mma_tiler_mn: Tuple[int, int], - a_source: OperandSource = OperandSource.SMEM, - *, - loc=None, - ip=None, -) -> cute.TiledMma: - """Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. - By default, the MMA atom is created with SMEM operand source for A. - - :param ab_dtype: Data type of operands A and B. - :type ab_dtype: type[Numeric] - :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). - :type a_leading_mode: tcgen05.OperandMajorMode - :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). - :type b_leading_mode: tcgen05.OperandMajorMode - :param sf_dtype: Data type of the Scale Factor. - :type sf_dtype: type[Numeric] - :param sf_vec_size: The vector size of the Scale Factor. - :type sf_vec_size: int - :param cta_group: The CTA group to use. - :type cta_group: tcgen05.CtaGroup - :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. - :type mma_tiler_mn: Tuple[int, int] - :param a_source: The source of operand A (SMEM by default or TMEM). - :type a_source: OperandSource - - :return: A tiled MMA atom. - :rtype: cute.TiledMma - - :raises TypeError: If the data type is not supported. - """ - if ab_dtype in {Float8E4M3FN, Float8E5M2}: - mma_op = MmaMXF8Op( - ab_dtype, - (*mma_tiler_mn, 32), - cta_group, - a_source, - a_leading_mode, - b_leading_mode, - ) - elif ab_dtype == Float4E2M1FN: - if sf_vec_size == 32: - mma_op = MmaMXF4Op( - (*mma_tiler_mn, 64), - cta_group, - a_source, - ) - elif sf_vec_size == 16: - mma_op = MmaMXF4NVF4Op( - sf_dtype, - (*mma_tiler_mn, 64), - cta_group, - a_source, - ) - else: - raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}") - else: - raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") - - return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) - - -@dsl_user_op -def cluster_shape_to_tma_atom_A( - cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None -) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: - """ - Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag. - - :param cluster_shape_mnk: The shape of the cluster - :type cluster_shape_mnk: cute.Shape - :param atom_thr_id: The thread ID of the atom - :type atom_thr_id: cute.Layout - - :return: The appropriate TMA copy atom kind - :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp - - :raise ValueError: If the atom_sm_cnt is invalid - :raise ValueError: If the cluster shape is not divisible by the atom SM count - """ - atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) - mcast = not (cute.size(cluster_shape_mnk, mode=[1], loc=loc, ip=ip) == 1) - cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) - - if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): - raise ValueError( - f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" - ) - - if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: - raise ValueError( - f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" - ) - - if atom_sm_cnt == 2 and mcast: - return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) - elif atom_sm_cnt == 2 and not mcast: - return CopyBulkTensorTileG2SOp(CtaGroup.TWO) - elif atom_sm_cnt == 1 and mcast: - return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) - elif atom_sm_cnt == 1 and not mcast: - return CopyBulkTensorTileG2SOp(CtaGroup.ONE) - - raise ValueError( - f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" - ) - - -@dsl_user_op -def cluster_shape_to_tma_atom_B( - cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None -) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: - """ - Select the appropriate TMA copy atom for Bbased on the number of SMs and the multicast flag. - - :param cluster_shape_mnk: The shape of the cluster - :type cluster_shape_mnk: cute.Shape - :param atom_thr_id: The thread ID of the atom - :type atom_thr_id: cute.Layout - - :return: The appropriate TMA copy atom kind - :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp - - :raise ValueError: If the atom_sm_cnt is invalid - :raise ValueError: If the cluster shape is not divisible by the atom SM count - """ - atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) - mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == atom_sm_cnt) - cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) - - if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): - raise ValueError( - f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" - ) - - if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: - raise ValueError( - f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" - ) - - if atom_sm_cnt == 2 and mcast: - return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) - elif atom_sm_cnt == 2 and not mcast: - return CopyBulkTensorTileG2SOp(CtaGroup.TWO) - elif atom_sm_cnt == 1 and mcast: - return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) - elif atom_sm_cnt == 1 and not mcast: - return CopyBulkTensorTileG2SOp(CtaGroup.ONE) - - raise ValueError( - f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" - ) - - -@dsl_user_op -def cluster_shape_to_tma_atom_SFB( - cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None -) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: - """ - Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag. - - :param cluster_shape_mnk: The shape of the cluster - :type cluster_shape_mnk: cute.Shape - :param atom_thr_id: The thread ID of the atom - :type atom_thr_id: cute.Layout - - :return: The appropriate TMA copy atom kind - :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp - - :raise ValueError: If the atom_sm_cnt is invalid - :raise ValueError: If the cluster shape is not divisible by the atom SM count - """ - atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) - mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == 1) - cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) - - if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): - raise ValueError( - f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" - ) - - if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: - raise ValueError( - f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" - ) - - if atom_sm_cnt == 2: - return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) - elif atom_sm_cnt == 1 and mcast: - return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) - elif atom_sm_cnt == 1 and not mcast: - return CopyBulkTensorTileG2SOp(CtaGroup.ONE) - - raise ValueError( - f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py deleted file mode 100644 index fa1e2eb70e38236d73f435e001fdc160d301c47c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py +++ /dev/null @@ -1,287 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from dataclasses import dataclass, field -from typing import Union - -from cutlass.cutlass_dsl import dsl_user_op - -import cutlass.cute as cute -from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir - - -@dataclass(frozen=True) -class BlockScaledBasicChunk: - """ - The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops. - - This class represents the fixed layout pattern for scale factors used in - tcgen05 BlockScaled MMA Ops. The layout is determined by the - instruction specification and cannot be modified. - See `PTX documentation `. - """ - - sf_vec_size: int - major_mode: OperandMajorMode = OperandMajorMode.K - _layout: cute.Layout = field(init=False, repr=False) - - def __post_init__(self) -> None: - if self.major_mode == OperandMajorMode.K: - # K-major layout: (AtomMN, AtomK) - atom_shape = ((32, 4), (self.sf_vec_size, 4)) - atom_stride = ((16, 4), (0, 1)) - else: - # MN-major layout: (AtomK, AtomMN) - atom_shape = ((self.sf_vec_size, 4), (32, 4)) - atom_stride = ((0, 1), (16, 4)) - - object.__setattr__( - self, "_layout", cute.make_layout(atom_shape, stride=atom_stride) - ) - - @property - def layout(self) -> cute.Layout: - """ - Get the layout for this block scaled chunk. - - :return: The layout representing the scale factor atom - :rtype: cute.Layout - """ - return self._layout - - -@dsl_user_op -def tile_atom_to_shape_SF( - Shape: cute.Shape, - sf_vec_size: int, - *, - loc=None, - ip=None, -) -> cute.Layout: - """ - A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. - - :param Shape: The shape of the A/B tensor - :param sf_vec_size: Scale factor vector size - - :return: The layout of the SFA/SFB tensor - :rtype: cute.Layout - """ - # ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL) - sf_layout = cute.tile_to_shape( - BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3) - ) - return sf_layout - - -@dsl_user_op -def make_smem_layout_sfa( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: cute.Tile, - sf_vec_size: int, - num_stages: int, - *, - loc=None, - ip=None, -) -> cute.Layout: - """ - Make smem layout for SFA based on: - 1. BlockScaledBasicChunk - 2. MMA tiler shape - 3. Scale factor vector size - 4. Number of stages - - :param tiled_mma: The tiled MMA - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The mma tiler shape - :type mma_tiler_mnk: cute.Tile - :param sf_vec_size: The scale factor vector size - :type sf_vec_size: int - :param num_stages: The number of stages - :type num_stages: int - - :return: Smem layout for SFA - :rtype: cute.Layout - """ - # (CTA_Tile_Shape_M, MMA_Tile_Shape_K) - sfa_tile_shape = ( - mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), - mma_tiler_mnk[2], - ) - - # ((Atom_M, Rest_M),(Atom_K, Rest_K)) - smem_layout = cute.tile_to_shape( - BlockScaledBasicChunk(sf_vec_size).layout, - sfa_tile_shape, - (2, 1), - ) - - mma_tile_inst_k = 4 - # (CTA_Tile_Shape_M, MMA_Inst_Shape_K) - sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k)) - # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)) - smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape) - - atom_m = 128 - tiler_inst = ((atom_m, sf_vec_size),) - # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) - smem_layout = cute.logical_divide(smem_layout, tiler_inst) - - # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) - sfa_smem_layout_staged = cute.append( - smem_layout, - cute.make_layout( - num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) - ), - ) - - return sfa_smem_layout_staged - - -@dsl_user_op -def make_smem_layout_sfb( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: cute.Tile, - sf_vec_size: int, - num_stages: int, - *, - loc=None, - ip=None, -) -> cute.Layout: - """ - Make smem layout for SFB based on: - 1. BlockScaledBasicChunk - 2. MMA tiler shape - 3. Scale factor vector size - 4. Number of stages - - :param tiled_mma: The tiled MMA - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The mma tiler shape - :type mma_tiler_mnk: cute.Tile - :param sf_vec_size: The scale factor vector size - :type sf_vec_size: int - :param num_stages: The number of stages - :type num_stages: int - - :return: Smem layout for SFA - :rtype: cute.Layout - """ - # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K) - sfb_tile_shape = ( - cute.round_up(mma_tiler_mnk[1], 128), - mma_tiler_mnk[2], - ) - - # ((Atom_N, Rest_N),(Atom_K, Rest_K)) - smem_layout = cute.tile_to_shape( - BlockScaledBasicChunk(sf_vec_size).layout, - sfb_tile_shape, - (2, 1), - ) - - mma_tile_inst_k = 4 - # (CTA_Tile_Shape_N, MMA_Inst_Shape_K) - sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k)) - # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K) - smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape) - - atom_n = 128 - tiler_inst = ((atom_n, sf_vec_size),) - # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) - smem_layout = cute.logical_divide(smem_layout, tiler_inst) - - # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) - sfb_smem_layout_staged = cute.append( - smem_layout, - cute.make_layout( - num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) - ), - ) - - return sfb_smem_layout_staged - - -@dsl_user_op -def make_tmem_layout_sfa( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: cute.Tile, - sf_vec_size: int, - smem_layout: cute.Layout, - *, - loc=None, - ip=None, -) -> cute.Layout: - """Make tmem layout for SFA based on: - 1. SFA smem layout per stage - 2. Cta tile shape m - 3. tiled MMA atom thr size - 4. Scale factor vector size - - :param tiled_mma: The tiled MMA - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The mma tiler shape - :type mma_tiler_mnk: cute.Tile - :param sf_vec_size: The scale factor vector size - :type sf_vec_size: int - :param smem_layout: The smem layout of SFA per stage - :type smem_layout: cute.Layout - - :return: TMEM layout for SFA - :rtype: cute.Layout - """ - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size - - sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa( - smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size - ) - return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip) - - -@dsl_user_op -def make_tmem_layout_sfb( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: cute.Tile, - sf_vec_size: int, - smem_layout: cute.Layout, - *, - loc=None, - ip=None, -) -> cute.Layout: - """Make tmem layout for SFB based on: - 1. SFB smem layout per stage - 2. Cta tile shape m - 3. tiled MMA atom thr size - 4. Scale factor vector size - - :param tiled_mma: The tiled MMA - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The mma tiler shape - :type mma_tiler_mnk: cute.Tile - :param sf_vec_size: The scale factor vector size - :type sf_vec_size: int - :param smem_layout: The smem layout of SFB per stage - :type smem_layout: cute.Layout - - :return: TMEM layout for SFB - :rtype: cute.Layout - """ - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size - - sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb( - smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size - ) - return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py deleted file mode 100644 index 5853c56c84f6fc02e911537147fa03b6b4566117..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from functools import partial -from typing import Tuple - -import cutlass.cute as cute -from cutlass.cutlass_dsl import T, dsl_user_op, while_generate - -from cutlass._mlir import ir -from cutlass._mlir.dialects import arith, llvm, nvvm, scf -from cutlass._mlir.dialects.nvvm import ( - MemOrderKind, - MemScopeKind, - AtomicOpKind, -) -from cutlass.cute.typing import Pointer, Int32, Boolean - - -@dsl_user_op -def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32: - return nvvm.atomicrmw( - T.i32(), - AtomicOpKind.ADD, - dst_ptr.llvm_ptr, - val.ir_value(loc=loc, ip=ip), - mem_order=MemOrderKind.RELAXED, - syncscope=MemScopeKind.SYS, - loc=loc, - ip=ip, - ) - - -@cute.jit -def ld_bypass(input_tensor: cute.Tensor): - fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type) - copy_atom_load = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - input_tensor.element_type, - memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, - memory_scope=cute.nvgpu.common.MemoryScope.SYS, - ) - cute.copy(copy_atom_load, input_tensor, fragment) - vals = fragment.load() - return vals - -@cute.jit -def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None: - """ - wait on a spin lock until the expected count is reached. - """ - res = 0 - while res != expect_count: - res = nvvm.atomicrmw( - T.i32(), - AtomicOpKind.CAS, - lock_ptr.llvm_ptr, - Int32(0).ir_value(loc=loc, ip=ip), - b=Int32(expect_count).ir_value(loc=loc, ip=ip), - mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED, - syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS - ) - - -@dsl_user_op -def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None: - """ - add 1 to the multimem address - """ - llvm.inline_asm( - None, - [mc_ptr.toint().ir_value()], - "multimem.red.release.sys.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - -@dsl_user_op -def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None: - """ - add 1 to the multimem address - """ - llvm.inline_asm( - None, - [mc_ptr.toint().ir_value()], - "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - -def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: - """ - arrive a spin lock when the lock_ptr is a multimem address. - """ - multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip) - - -def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None : - """ - barrier for inter-gpu sm-wise - """ - bidx, bidy, bidz = cute.arch.block_idx() - bdimx, bdimy, _ = cute.arch.grid_dim() - pid = bidx + bidy * bdimx + bidz * bdimx * bdimy - multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip) - cute.arch.fence_proxy(cute.arch.ProxyKind.alias) - spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip) - - -@dsl_user_op -def multimem_ld_reduce_base( - mc_ptr: Pointer, - *, - ptx_string: str = "", - loc=None, - ip=None, -) -> Tuple[Int32, Int32, Int32, Int32]: - # ld reduce 8xf16 elts - mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() - return_struct = llvm.inline_asm( - ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), - [mc_ptr_int], - ptx_string, - "=r,=r,=r,=r,l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)] - return return_regs[0], return_regs[1], return_regs[2], return_regs[3] - - -multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];") - - -@dsl_user_op -def multimem_st_4xb32( - mc_ptr: Pointer, - x: Int32, - y: Int32, - z: Int32, - w: Int32, - *, - loc=None, - ip=None, -) -> None: - # st 4x32 bits of data - mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() - llvm.inline_asm( - T.i32(), - [mc_ptr_int, x, y, z, w], - "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};", - "=r,l,r,r,r,r", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py deleted file mode 100644 index a51bae62963bd482fd590f824a4bc1c8564ece0e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py +++ /dev/null @@ -1,466 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import List, Tuple - -import cutlass.cute as cute -from cutlass.cutlass_dsl import Int32, extract_mlir_values, new_from_mlir_values -from cutlass._mlir import ir - -from cutlass.utils.static_persistent_tile_scheduler import PersistentTileSchedulerParams - - -class GroupSearchResult: - """ - The result of the group search for grouped gemm. - - :param group_idx: The result group index - :type group_idx: Int32 - :param cta_tile_idx_m: CTA tile index along M dimension after rasterization - :type cta_tile_idx_m: Int32 - :param cta_tile_idx_n: CTA tile index along N dimension after rasterization - :type cta_tile_idx_n: Int32 - :param problem_shape_m: The M dimension of the gemm problem - :type problem_shape_m: Int32 - :param problem_shape_n: The N dimension of the gemm problem - :type problem_shape_n: Int32 - :param problem_shape_k: The K dimension of the gemm problem - :type problem_shape_k: Int32 - :param cta_tile_count_k: Number of tiles along K dimension - :type cta_tile_count_k: Int32 - """ - - def __init__( - self, - group_idx: Int32, - cta_tile_idx_m: Int32, - cta_tile_idx_n: Int32, - problem_shape_m: Int32, - problem_shape_n: Int32, - problem_shape_k: Int32, - cta_tile_count_k: Int32, - ) -> None: - self.group_idx = group_idx - self.cta_tile_idx_m = cta_tile_idx_m - self.cta_tile_idx_n = cta_tile_idx_n - self.problem_shape_m = problem_shape_m - self.problem_shape_n = problem_shape_n - self.problem_shape_k = problem_shape_k - self.cta_tile_count_k = cta_tile_count_k - - def __extract_mlir_values__(self) -> List[ir.Value]: - values = extract_mlir_values(self.group_idx) - values.extend(extract_mlir_values(self.cta_tile_idx_m)) - values.extend(extract_mlir_values(self.cta_tile_idx_n)) - values.extend(extract_mlir_values(self.problem_shape_m)) - values.extend(extract_mlir_values(self.problem_shape_n)) - values.extend(extract_mlir_values(self.problem_shape_k)) - values.extend(extract_mlir_values(self.cta_tile_count_k)) - return values - - def __new_from_mlir_values__(self, values: List[ir.Value]) -> "GroupSearchResult": - assert len(values) == 7 - return GroupSearchResult(*tuple(values)) - - -class GroupedGemmGroupSearchState: - """ - The state of group index search for grouped gemm. - - The state will be initialized once and updated in every round of group index search. - - :param start_group_idx: The group idx to start the search with - :type start_group_idx: Int32 - :param tile_count_prev_group: Number of tiles before the matched group - :type tile_count_prev_group: Int32 - :param tile_count_searched: Number of tiles we have searched. When the matched group is found, - it records the number of tiles including the matched group - :type tile_count_searched: Int32 - """ - - def __init__( - self, - start_group_idx: Int32, - tile_count_prev_group: Int32, - tile_count_searched: Int32, - ) -> None: - self.start_group_idx = start_group_idx - self.tile_count_prev_group = tile_count_prev_group - self.tile_count_searched = tile_count_searched - - def __extract_mlir_values__(self) -> List[ir.Value]: - values = extract_mlir_values(self.start_group_idx) - values.extend(extract_mlir_values(self.tile_count_prev_group)) - values.extend(extract_mlir_values(self.tile_count_searched)) - return values - - def __new_from_mlir_values__( - self, values: List[ir.Value] - ) -> "GroupedGemmGroupSearchState": - start_group_idx = new_from_mlir_values(self.start_group_idx, [values[0]]) - tile_count_prev_group = new_from_mlir_values( - self.tile_count_prev_group, [values[1]] - ) - tile_count_searched = new_from_mlir_values( - self.tile_count_searched, [values[2]] - ) - return GroupedGemmGroupSearchState( - start_group_idx, tile_count_prev_group, tile_count_searched - ) - - -def create_initial_search_state() -> GroupedGemmGroupSearchState: - """ - Create an initial search state for grouped gemm. - - :return: A new search state with initial values - :rtype: GroupedGemmGroupSearchState - """ - return GroupedGemmGroupSearchState( - start_group_idx=Int32(0), - tile_count_prev_group=Int32(0), - tile_count_searched=Int32(0), - ) - - -class GroupedGemmTileSchedulerHelper: - """ - A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm. - - :param group_count: Number of groups in current grouped gemm problem - :type group_count: int - :param tile_sched_params: Parameter used to create the tile scheduler this helper works with - :type tile_sched_params: PersistentTileSchedulerParams - :param cluster_tile_shape_mnk: The shape of cluster tile as (m, n, k) - :type cluster_tile_shape_mnk: tuple[int, int, int] - :param search_state: The initial search state - :type search_state: GroupedGemmGroupSearchState - """ - - def __init__( - self, - group_count: int, - tile_sched_params: PersistentTileSchedulerParams, - cluster_tile_shape_mnk: tuple[int, int, int], - search_state: GroupedGemmGroupSearchState, - ) -> None: - self.tile_sched_params = tile_sched_params - self.group_count = group_count - self.lane_idx = cute.arch.lane_idx() - self.cluster_tile_shape_mnk = cluster_tile_shape_mnk - self.search_state = search_state - - def __extract_mlir_values__(self) -> List[ir.Value]: - values = extract_mlir_values(self.tile_sched_params) - values.extend(extract_mlir_values(self.search_state)) - return values - - def __new_from_mlir_values__( - self, values: List[ir.Value] - ) -> "GroupedGemmTileSchedulerHelper": - tile_sched_params = new_from_mlir_values(self.tile_sched_params, values) - search_state = new_from_mlir_values(self.search_state, values[1:]) - return GroupedGemmTileSchedulerHelper( - self.group_count, - tile_sched_params, - self.cluster_tile_shape_mnk, - search_state, - ) - - def delinearize_z( - self, - cta_tile_coord: tuple, - problem_shape_mnkl: cute.Tensor, - ) -> GroupSearchResult: - """ - Delinearize the linear z index and return GroupSearchResult. - - This function should be used by warps that need to know the CTA tile index on M and N dimensions. - - :param cta_tile_coord: The raw CTA coordinate from tile scheduler - :type cta_tile_coord: tuple of Int32 - :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for each group - :type problem_shape_mnkl: cute.Tensor - :return: The search result containing group index and tile coordinates - :rtype: GroupSearchResult - """ - # delinear the z coord - linear_idx = cta_tile_coord[2] - group_idx, problem_mnkl = self._group_search_and_load_problem_shape( - linear_idx, - problem_shape_mnkl, - self.search_state.start_group_idx, - self.search_state.tile_count_prev_group, - ) - # linear index local to current group - cluster_tile_idx_in_current_group = ( - linear_idx - self.search_state.tile_count_prev_group - ) - cluster_count_m, cluster_count_n, cluster_count_k = cute.ceil_div( - (problem_mnkl[0], problem_mnkl[1], problem_mnkl[2]), - ( - self.cluster_tile_shape_mnk[0], - self.cluster_tile_shape_mnk[1], - self.cluster_tile_shape_mnk[2], - ), - ) - # decompose to get indices on M and N - cta_tile_idx_m, cta_tile_idx_n = self._compute_cta_tile_coord( - cluster_tile_idx_in_current_group, - cta_tile_coord, - cluster_count_m, - cluster_count_n, - ) - return GroupSearchResult( - group_idx, - cta_tile_idx_m, - cta_tile_idx_n, - problem_mnkl[0], - problem_mnkl[1], - problem_mnkl[2], - cluster_count_k, - ) - - def search_cluster_tile_count_k( - self, - cta_tile_coord: tuple, - problem_shape_mnkl: cute.Tensor, - ) -> Tuple[Int32, Int32]: - """ - Search the matched group for given linear index and compute the number of tiles along K dimension for the matched group. - - This function should be used by warps that are only interested in the number of tiles along K dimension. - - :param cta_tile_coord: The raw CTA coordinate from tile scheduler - :type cta_tile_coord: tuple of Int32 - :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups - :type problem_shape_mnkl: cute.Tensor - :return: A tuple containing cluster count along K dimension and the group index - :rtype: Tuple[Int32, Int32] - """ - group_idx, problem_mnk = self._group_search_and_load_problem_shape( - cta_tile_coord[2], - problem_shape_mnkl, - self.search_state.start_group_idx, - self.search_state.tile_count_prev_group, - ) - cluster_count_k = ( - problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1 - ) // self.cluster_tile_shape_mnk[2] - return cluster_count_k, group_idx - - @cute.jit - def _prefix_sum(self, value_per_thread: Int32) -> Int32: - """ - Perform prefix sum within a full warp. - - :param value_per_thread: The value for this thread to contribute to the prefix sum - :type value_per_thread: Int32 - :return: The prefix sum result for this thread - :rtype: Int32 - """ - clamp_value = 0 - idx = 1 - sum_per_thread = value_per_thread - while idx < cute.arch.WARP_SIZE: - value = cute.arch.shuffle_sync_up( - sum_per_thread, idx, mask_and_clamp=clamp_value - ) - if self.lane_idx >= idx: - sum_per_thread += value - idx = idx << 1 - return sum_per_thread - - def _get_problem_for_group( - self, problem_shape_mnkl: cute.Tensor, group_idx: Int32 - ) -> cute.Tensor: - """ - Load gemm problem (m,n,k,l) for the specified group from global memory to register. - - :param problem_shape_mnkl: Tensor in global memory with layout (group_count, 4):(4, 1) - :type problem_shape_mnkl: cute.Tensor - :param group_idx: The index of the group to load - :type group_idx: Int32 - :return: The problem shape tensor for the specified group - :rtype: cute.Tensor - """ - cur_problem_mnkl = cute.make_fragment( - cute.make_layout(4), problem_shape_mnkl.element_type - ) - cute.autovec_copy(problem_shape_mnkl[(group_idx, None)], cur_problem_mnkl) - return cur_problem_mnkl - - def _get_cluster_tile_count_mn(self, problem_shape: cute.Tensor) -> Int32: - """ - Compute total cluster count. - - :param problem_shape: Tensor containing problem shape (m, n, k, l) - :type problem_shape: cute.Tensor - :return: The total cluster tile count for M and N dimensions - :rtype: Int32 - """ - cur_ntile_m = ( - problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1 - ) // self.cluster_tile_shape_mnk[0] - cur_ntile_n = ( - problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1 - ) // self.cluster_tile_shape_mnk[1] - cur_ntile_mn = cur_ntile_m * cur_ntile_n - return cur_ntile_mn - - def _compute_cta_tile_coord( - self, - cluster_tile_idx: Int32, - cta_tile_coord_in_cluster: tuple, - cluster_tile_count_m: Int32, - cluster_tile_count_n: Int32, - ) -> tuple: - """ - Compute CTA tile indices along M and N dimensions based on the linear index within a group. - - It uses the AlongM mode to decompose the linear index onto M and N dimensions. - - :param cluster_tile_idx: The linear index within a group - :type cluster_tile_idx: Int32 - :param cta_tile_coord_in_cluster: CTA indices along M and N dimensions within a cluster - :type cta_tile_coord_in_cluster: tuple of Int32 - :param cluster_tile_count_m: The number of clusters along M dimension of the matched group - :type cluster_tile_count_m: Int32 - :param cluster_tile_count_n: The number of clusters along N dimension of the matched group - :type cluster_tile_count_n: Int32 - :return: A tuple containing CTA tile indices along M and N dimensions - :rtype: tuple of (Int32, Int32) - """ - cluster_layout_mn = cute.make_layout( - (cluster_tile_count_m, cluster_tile_count_n) - ) - (mi, ni) = cluster_layout_mn.get_hier_coord(cluster_tile_idx) - cta_tile_idx_m = ( - mi * self.tile_sched_params.cluster_shape_mn[0] - + cta_tile_coord_in_cluster[0] - ) - cta_tile_idx_n = ( - ni * self.tile_sched_params.cluster_shape_mn[1] - + cta_tile_coord_in_cluster[1] - ) - return (cta_tile_idx_m, cta_tile_idx_n) - - @cute.jit - def _group_search( - self, - linear_idx: Int32, - problem_shape_mnkl: cute.Tensor, - init_group_idx: Int32, - init_tile_count_searched: Int32, - ) -> GroupedGemmGroupSearchState: - """ - Search which group the linear index belongs to. - - :param linear_idx: The linear index to be decomposed - :type linear_idx: Int32 - :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups - :type problem_shape_mnkl: cute.Tensor - :param init_group_idx: The group idx to start the search with - :type init_group_idx: Int32 - :param init_tile_count_searched: The number of tiles we have searched - :type init_tile_count_searched: Int32 - :return: The updated search state - :rtype: GroupedGemmGroupSearchState - """ - c_0 = Int32(0).ir_value() - last_lane_idx = cute.arch.WARP_SIZE - 1 - - tile_count_searched = init_tile_count_searched - start_group_idx = init_group_idx - not_found = linear_idx >= tile_count_searched - tile_count_prev_group = self.search_state.tile_count_prev_group - while not_found: - # get group to search for current lane - cur_group_idx = start_group_idx + self.lane_idx - # check if the group to be checked is out of range - inside_group_bound = cur_group_idx < self.group_count - cur_ntile_mn = c_0 - if inside_group_bound: - # get problem size of current group - cur_problem_mnkl = self._get_problem_for_group( - problem_shape_mnkl, cur_group_idx - ) - cur_ntile_mn = self._get_cluster_tile_count_mn(cur_problem_mnkl) - # compute tile count from beginning to current group(included) - total_cluster_tile_count_ps_per_thread = self._prefix_sum(cur_ntile_mn) - cluster_tile_count_end_per_thread = ( - total_cluster_tile_count_ps_per_thread + tile_count_searched - ) - - group_not_in_window = linear_idx >= cluster_tile_count_end_per_thread - hitted_group_idx_in_search_window = cute.arch.popc( - cute.arch.vote_ballot_sync(group_not_in_window) - ) - not_found = hitted_group_idx_in_search_window == cute.arch.WARP_SIZE - start_group_idx = hitted_group_idx_in_search_window + start_group_idx - hit_the_1st_problem_in_search_window = ( - hitted_group_idx_in_search_window == c_0 - ) - tile_count_prev_group = tile_count_searched - if hit_the_1st_problem_in_search_window == False: - tile_count_prev_group = cute.arch.shuffle_sync( - cluster_tile_count_end_per_thread, - hitted_group_idx_in_search_window - 1, - ) - - # If no matched group, then get new_cluster_tile_count_end from last lane - # Otherwise, get new_cluster_tile_count_end from the hitted group - lane_idx_for_cluster_tile_count_end = hitted_group_idx_in_search_window - if not_found: - lane_idx_for_cluster_tile_count_end = last_lane_idx - tile_count_searched = cute.arch.shuffle_sync( - cluster_tile_count_end_per_thread, - lane_idx_for_cluster_tile_count_end, - ) - - return GroupedGemmGroupSearchState( - start_group_idx, - tile_count_prev_group, - tile_count_searched, - ) - - def _group_search_and_load_problem_shape( - self, - linear_idx: Int32, - problem_shape_mnkl: cute.Tensor, - start_group_idx: Int32, - tile_count_searched: Int32, - ) -> Tuple[Int32, cute.Tensor]: - """ - Perform group search and load problem shape for the matched group. - - :param linear_idx: The linear index to be decomposed - :type linear_idx: Int32 - :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups - :type problem_shape_mnkl: cute.Tensor - :param start_group_idx: The group idx to start the search with - :type start_group_idx: Int32 - :param tile_count_searched: The number of tiles we have searched - :type tile_count_searched: Int32 - :return: A tuple containing the final group index and the problem shape tensor - :rtype: Tuple[Int32, cute.Tensor] - """ - self.search_state = self._group_search( - linear_idx, - problem_shape_mnkl, - start_group_idx, - tile_count_searched, - ) - # get final group search state - final_group_idx = self.search_state.start_group_idx - # let's revisit if it's better to broadcast problem_shape_mnk in group_search - problem_mnkl = self._get_problem_for_group(problem_shape_mnkl, final_group_idx) - return final_group_idx, problem_mnkl diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py deleted file mode 100644 index e86fcbefc86fbc7da333735fa2cebbd3af47f39e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py +++ /dev/null @@ -1,174 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from cuda.bindings import driver, nvrtc - -import cutlass.cute as cute - -""" -This class is used to get the hardware info of given GPU device. -It provides methods to get the max active clusters for given cluster size. - -Prerequisite: -- CUDA driver is initialized via `driver.cuInit` or other CUDA APIs. -- CUDA context is created via `driver.cuCtxCreate` or other CUDA APIs. - -""" - - -class HardwareInfo: - """ - device_id: CUDA device ID to get the hardware info. - """ - - def __init__(self, device_id: int = 0): - count = self._checkCudaErrors(driver.cuDeviceGetCount()) - if device_id >= count: - raise ValueError( - f"Device ID {device_id} is out of range for device count {count}" - ) - self.device_id = device_id - self.device = self._checkCudaErrors(driver.cuDeviceGet(device_id)) - self.context = self._checkCudaErrors(driver.cuCtxGetCurrent()) - self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion()) - - # Getting the max active clusters for a given cluster size - def get_max_active_clusters(self, cluster_size: int) -> int: - self._get_device_function() - if self._cuda_driver_version_lt(11, 8): - raise RuntimeError( - "CUDA Driver version < 11.8, cannot get _max_active_clusters" - ) - if cluster_size <= 0 or cluster_size > 32: - raise ValueError( - f"Cluster size must be between 1 and 32, {cluster_size} is not supported" - ) - - max_shared_memory_per_block = self._checkCudaErrors( - driver.cuDeviceGetAttribute( - driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - self.device, - ) - ) - self._checkCudaErrors( - driver.cuFuncSetAttribute( - self.kernel, - driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - max_shared_memory_per_block, - ) - ) - max_dynamic_shared_memory = self._checkCudaErrors( - driver.cuOccupancyAvailableDynamicSMemPerBlock( - self.kernel, 1, 1 # numBlocks # blockSize - ) - ) - max_active_blocks = self._checkCudaErrors( - driver.cuOccupancyMaxActiveBlocksPerMultiprocessor( - self.kernel, 1, max_dynamic_shared_memory # blockSize, - ) - ) - # allow non-portable cluster size to support detection of non-portable cluster size - self._checkCudaErrors( - driver.cuFuncSetAttribute( - self.kernel, - driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, - 1, - ) - ) - # prepare launch configuration - launch_config = driver.CUlaunchConfig() - launch_config.blockDimX = 128 - launch_config.blockDimY = 1 - launch_config.blockDimZ = 1 - launch_config.sharedMemBytes = max_dynamic_shared_memory - launch_config.numAttrs = 1 - # max possible cluster size is 32 - cluster_dims_attr = driver.CUlaunchAttribute() - cluster_dims_attr.id = ( - driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - ) - value = driver.CUlaunchAttributeValue() - value.clusterDim.x = cluster_size - value.clusterDim.y = 1 - value.clusterDim.z = 1 - cluster_dims_attr.value = value - launch_config.attrs = [cluster_dims_attr] - launch_config.gridDimX = cluster_size - launch_config.gridDimY = max_active_blocks - launch_config.gridDimZ = 1 - - num_clusters = self._checkCudaErrors( - driver.cuOccupancyMaxActiveClusters(self.kernel, launch_config) - ) - return num_clusters - - def get_l2_cache_size_in_bytes(self) -> int: - return self._checkCudaErrors( - driver.cuDeviceGetAttribute( - driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, - self.device, - ) - ) - - def get_device_multiprocessor_count(self) -> int: - return self._checkCudaErrors( - driver.cuDeviceGetAttribute( - driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, - self.device, - ) - ) - - def _checkCudaErrors(self, result) -> None: - if result[0].value: - raise RuntimeError( - "CUDA error code={}({})".format( - result[0].value, self._cudaGetErrorEnum(result[0]) - ) - ) - # CUDA APIs always return the status as the first element of the result tuple - if len(result) == 1: - return None - elif len(result) == 2: - return result[1] - else: - return result[1:] - - def _cudaGetErrorEnum(self, error) -> str: - if isinstance(error, driver.CUresult): - err, name = driver.cuGetErrorName(error) - return name if err == driver.CUresult.CUDA_SUCCESS else "" - elif isinstance(error, nvrtc.nvrtcResult): - return nvrtc.nvrtcGetErrorString(error)[1] - else: - raise RuntimeError("Unknown error type: {}".format(error)) - - def _cuda_driver_version_ge(self, major: int, minor: int) -> bool: - return self.driver_version >= (major * 1000 + 10 * minor) - - def _cuda_driver_version_lt(self, major: int, minor: int) -> bool: - return not self._cuda_driver_version_ge(major, minor) - - @cute.kernel - def _empty_kernel(self): - return - - @cute.jit - def _host_function(self): - self._empty_kernel().launch( - grid=[1, 1, 1], - block=[1, 1, 1], - ) - - # get a empty kernel to compute occupancy - def _get_device_function(self) -> None: - self.compiled_kernel = cute.compile(self._host_function) - self.module = next(iter(self.compiled_kernel.cuda_modules.modules)).cuda_module - self.kernel = next(iter(self.compiled_kernel.cuda_modules.modules)).kernel_ptr diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py deleted file mode 100644 index 4cd2bae3de66983dc5bf7883305f6a926b3c0d72..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Type, Tuple -from enum import Enum -from typing_extensions import deprecated -import warnings - -from cutlass.utils.layout import LayoutEnum -from cutlass.cutlass_dsl import ( - Float16, - BFloat16, - Float8E5M2, - Float8E4M3FN, - Numeric, - NumericMeta, - dsl_user_op, -) - -import cutlass -import cutlass.cute as cute -from cutlass.cute.nvgpu.common import CopyUniversalOp -from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp -from cutlass.cute.nvgpu.warpgroup import ( - MmaF16BF16Op, - MmaF8Op, - OperandMajorMode, - OperandSource, -) - - -@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") -class SmemCapacity(Enum): - SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 - - -warnings.warn( - "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", - DeprecationWarning, - stacklevel=2, -) -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value, -} - - -@dsl_user_op -def sm90_get_smem_store_op( - layout_d: LayoutEnum, - elem_ty_d: Type[Numeric], - elem_ty_acc: Type[Numeric], - *, - loc=None, - ip=None, -) -> cute.CopyAtom: - """ - Selects the largest vectorized smem store atom available subject to constraint of gmem layout. - - Parameters: - ----------- - layout_d : LayoutEnum - The layout enum of the output tensor D. - - elem_ty_d : Type[Numeric] - The element type for output tensor D. - - elem_ty_acc : Type[Numeric] - The element type for accumulator. - - Returns: - -------- - Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. - """ - - def validate_type(ty, ty_name): - if not isinstance(ty, NumericMeta): - raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") - - validate_type(elem_ty_d, "elem_ty_d") - validate_type(elem_ty_acc, "elem_ty_acc") - - is_m_major = layout_d.is_m_major_c() - - if elem_ty_d.width == 16: - return cute.make_copy_atom( - StMatrix8x8x16bOp(is_m_major, 4), elem_ty_d, loc=loc, ip=ip - ) - else: - return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) - - -def make_trivial_tiled_mma( - a_dtype: Type[Numeric], - b_dtype: Type[Numeric], - a_leading_mode: OperandMajorMode, - b_leading_mode: OperandMajorMode, - acc_dtype: Type[Numeric], - atom_layout_mnk: Tuple[int, int, int], - tiler_mn: Tuple[int, int], - a_source: OperandSource = OperandSource.SMEM, - *, - loc=None, - ip=None, -) -> cute.TiledMma: - """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. - By default, the MMA atom is created with SMEM operand source for A. - - :param a_dtype: Data type of operand A. - :type a_dtype: type[Numeric] - :param b_dtype: Data type of operand B. - :type b_dtype: type[Numeric] - :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). - :type a_leading_mode: warpgroup.OperandMajorMode - :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). - :type b_leading_mode: warpgroup.OperandMajorMode - :param acc_dtype: Data type of the accumulator. - :type acc_dtype: type[Numeric] - :param atom_layout_mnk: A integer tuple describing the tiling of Atom across threads. - :type atom_layout_mnk: Tuple[int, int, int] - :param tiler_mn: The shape (M, N) of the cta tiler. - :type tiler_mn: Tuple[int, int] - - :return: A tiled MMA atom. - :rtype: cute.TiledMma - - :raises TypeError: If the data type is not supported. - """ - - if a_dtype in {Float16, BFloat16}: - if cutlass.const_expr(a_dtype != b_dtype): - raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}") - if cutlass.const_expr(a_dtype.width != b_dtype.width): - raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}") - - mma_op = MmaF16BF16Op( - a_dtype, - acc_dtype, - (*tiler_mn, 16), - a_source, - a_leading_mode, - b_leading_mode, - ) - elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in { - Float8E4M3FN, - Float8E5M2, - }: - mma_op = MmaF8Op( - a_dtype, - b_dtype, - acc_dtype, - (*tiler_mn, 32), - a_source, - a_leading_mode, - b_leading_mode, - ) - else: - raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}") - - return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk) - -def get_smem_layout_atom( - layout: LayoutEnum, - element_type: Type[Numeric], - major_mode_size: int, - *, - loc=None, - ip=None, -): - """Select the optimal shared memory layout atom based on parameters. - - :param layout: Layout enum of the tensor - :type layout: LayoutEnum - :param element_type: Data type of the elements - :type element_type: type[cutlass.Numeric] - :param major_mode_size: Size of the major mode dimension - :type major_mode_size: int - - :return: Selected shared memory layout atom kind - :rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind - """ - assert major_mode_size % 8 == 0 - sw128_num_contiguous_bits = 1024 - sw64_num_contiguous_bits = 512 - sw32_num_contiguous_bits = 256 - major_mode_size_bits = major_mode_size * element_type.width - if layout.sm90_mma_major_mode() == OperandMajorMode.MN: - if major_mode_size_bits % sw128_num_contiguous_bits == 0: - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128 - if major_mode_size_bits % sw64_num_contiguous_bits == 0: - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64 - if major_mode_size_bits % sw32_num_contiguous_bits == 0: - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32 - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER - if major_mode_size_bits % sw128_num_contiguous_bits == 0: - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128 - if major_mode_size_bits % sw64_num_contiguous_bits == 0: - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64 - if major_mode_size_bits % sw32_num_contiguous_bits == 0: - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32 - return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py deleted file mode 100644 index 4560c266cf9930ac024adeaa94859d06ecf3650a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from enum import Enum - -import cutlass.cute as cute -from cutlass.cute.nvgpu import warpgroup -from cutlass.cute.nvgpu import tcgen05 - - -class LayoutEnum(Enum): - ROW_MAJOR = "row_major" - COL_MAJOR = "col_major" - - def mma_major_mode(self): - return ( - tcgen05.OperandMajorMode.K - if self == LayoutEnum.ROW_MAJOR - else tcgen05.OperandMajorMode.MN - ) - - def sm90_mma_major_mode(self): - return ( - warpgroup.OperandMajorMode.K - if self == LayoutEnum.ROW_MAJOR - else warpgroup.OperandMajorMode.MN - ) - - def is_n_major_c(self): - return self == LayoutEnum.ROW_MAJOR - - def is_m_major_c(self): - return self == LayoutEnum.COL_MAJOR - - @staticmethod - def from_tensor(tensor: cute.Tensor) -> "LayoutEnum": - ret = None - if tensor.leading_dim == 1: - ret = LayoutEnum.ROW_MAJOR - elif tensor.leading_dim == 0: - ret = LayoutEnum.COL_MAJOR - else: - raise ValueError(f"Invalid leading dimension: {tensor.leading_dim}") - - return ret - - -__all__ = ["LayoutEnum"] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py deleted file mode 100644 index 2500c06e1808bc06db5decce88e8ebf7837f17d0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ /dev/null @@ -1,184 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Type, Union, overload - -from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL - -import cutlass.cute as cute -from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size - - -class SmemAllocator: - """A class for managing shared memory allocation on GPU. - - This class manages a chunk of shared memory and provides APIs for sub-allocation - inside the chunk. - - :ivar _base: The current base address of the shared memory as an i8 typed dynamic value. - :type _base: cute.Pointer - :ivar _allocated_bytes: The total number of bytes allocated in shared memory. - :type _allocated_bytes: int - - .. note:: - This class is responsible for managing the allocation of tensors in shared memory. - The base pointer is aligned to 1024 bytes upon initialization. - """ - - def __init__(self): - """Initialize the SmemAllocator instance. - - Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes. - """ - self._base = get_dyn_smem(Int8, alignment=1024) - self._allocated_bytes = 0 - CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) - - @overload - def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ... - - @overload - def allocate( - self, size_or_type: cute.struct, byte_alignment: int - ) -> cute.Pointer: ... - - def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer: - """Allocate a block of memory with specified size and alignment. - - This method adjusts the base pointer to ensure proper alignment and updates - the internal state to track allocated memory. - - :param size_or_type: The number of bytes to allocate or a struct class - :type size_or_type: Union[int, cute.struct] - :param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment) - :type byte_alignment: int, optional - :return: Pointer to the start of the allocated memory block or struct instance - :rtype: cute.Pointer - :raises ValueError: If size is negative or alignment is less than 1 - :raises RuntimeError: If allocation would exceed available shared memory - """ - if isinstance(size_or_type, cute.struct): - alignment = max(byte_alignment, size_or_type.__alignof__()) - base_ptr = self.allocate(size_or_type.__sizeof__(), alignment) - return size_or_type(base_ptr) - - num_bytes = size_or_type - if num_bytes < 0: - raise ValueError("num_bytes must be non-negative") - if byte_alignment < 1: - raise ValueError("byte_alignment must be at least 1") - - self._base = self._base.align(byte_alignment) - ptr = self._base - self._base += num_bytes - if self._allocated_bytes % byte_alignment != 0: - self._allocated_bytes += ( - byte_alignment - self._allocated_bytes % byte_alignment - ) - self._allocated_bytes += num_bytes - - # Check bounds against available dynamic shared memory - cute.testing.assert_( - self._allocated_bytes <= get_dyn_smem_size(), - f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. " - f"Allocated bytes: {self._allocated_bytes} bytes. " - f"Please reduce the allocation or set a larger smem size in kernel launch.", - ) - return ptr - - def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1): - """Allocate an array of elements in shared memory. - - :param element_type: The type of elements to allocate - :type element_type: Type[Numeric] - :param num_elems: Number of elements to allocate, defaults to 1 - :type num_elems: int, optional - :return: Pointer to the start of the allocated array - :rtype: cute.Pointer - :raises ValueError: If num_elems is less than 1 - :raises TypeError: If element_type is not a Numeric type - """ - if num_elems < 1: - raise ValueError("num_elems must be at least 1") - if not isinstance(element_type, NumericMeta): - raise TypeError( - f"value_ty must be a type of Numeric, but got {element_type}" - ) - - ptr = self.allocate( - element_type.width // 8 * num_elems, element_type.width // 8 - ) - - return cute.recast_ptr(ptr, dtype=element_type) - - def allocate_tensor( - self, - element_type: Type[Numeric], - layout: Union[int, cute.Layout, cute.ComposedLayout], - byte_alignment: int = 1, - swizzle: cute.Swizzle = None, - ): - """Allocate a tensor in shared memory. - - :param element_type: The type of elements in the tensor - :type element_type: Type[Numeric] - :param layout: The layout specification for the tensor - :type layout: Union[int, cute.Layout, cute.ComposedLayout] - :param byte_alignment: The byte alignment requirement, defaults to 1 - :type byte_alignment: int, optional - :param swizzle: Swizzle for position-dependent swizzling, defaults to None - :type swizzle: cute.Swizzle, optional - :return: The allocated tensor with specified properties - :rtype: cute.Tensor - :raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout - :raises ValueError: If allocation is not byte-aligned - :raises NotImplementedError: If dynamic layout is specified - """ - if not isinstance(element_type, NumericMeta): - raise TypeError( - f"value_ty must be a type of Numeric, but got {element_type}" - ) - - if ( - isinstance(layout, cute.ComposedLayout) - and isinstance(layout.inner, cute.Swizzle) - ) and (swizzle is not None): - raise TypeError( - f"Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time." - ) - - if isinstance(layout, int): - layout = cute.make_layout(layout) - - profile = layout(0) - if isinstance(profile, tuple): - raise TypeError( - f"cannot allocate a shared memory tensor with a non-integer iterator" - ) - - if not cute.is_static(layout.type): - raise NotImplementedError(f"dynamic layout is not supported: {layout.type}") - - # At least align the allocation to the natural alignment given by the element type - if element_type.width // 8 > byte_alignment: - byte_alignment = element_type.width // 8 - - # Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned - cosize_in_bits = cute.cosize(layout) * element_type.width - assert isinstance(cosize_in_bits, int) - if cosize_in_bits % 8 != 0: - raise ValueError("invalid allocation that is not byte-aligned") - - num_bytes = cosize_in_bits // 8 - ptr = self.allocate(num_bytes, byte_alignment) - ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type) - res = cute.make_tensor(ptr, layout) - return res diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py deleted file mode 100644 index 87ddb990436caf8135a849b3a37bf52632eed2fc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - - -SMEM_CAPACITY_MAP = { - "sm_120": (100 - 1) * 1024, - "sm_100": (228 - 1) * 1024, - "sm_90": (228 - 1) * 1024, - "sm_80": (164 - 1) * 1024, - "sm_86": (100 - 1) * 1024, - "sm_89": (100 - 1) * 1024, -} - - -def get_smem_capacity_in_bytes(compute_capability: str) -> int: - if compute_capability not in SMEM_CAPACITY_MAP: - raise ValueError(f"Unsupported compute capability: {compute_capability}") - return SMEM_CAPACITY_MAP[compute_capability] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py deleted file mode 100644 index 2873244d7cce9d8072f1fa71bbba1762022631b9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py +++ /dev/null @@ -1,386 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Tuple - -from cutlass.cutlass_dsl import ( - Boolean, - Integer, - Int32, - min, - extract_mlir_values, - new_from_mlir_values, - dsl_user_op, -) -from cutlass._mlir import ir -import cutlass.cute as cute - -############################################################################## -# Static persistent tile scheduler -############################################################################## - - -class WorkTileInfo: - """A class to represent information about a work tile. - - :ivar tile_idx: The index of the tile. - :type tile_idx: cute.Coord - :ivar is_valid_tile: Whether the tile is valid. - :type is_valid_tile: Boolean - """ - - def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean): - self._tile_idx = tile_idx - self._is_valid_tile = Boolean(is_valid_tile) - - def __extract_mlir_values__(self) -> list[ir.Value]: - values = extract_mlir_values(self.tile_idx) - values.extend(extract_mlir_values(self.is_valid_tile)) - return values - - def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": - assert len(values) == 4 - new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1]) - new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]]) - return WorkTileInfo(new_tile_idx, new_is_valid_tile) - - @property - def is_valid_tile(self) -> Boolean: - """Check latest tile returned by the scheduler is valid or not. Any scheduling - requests after all tasks completed will return an invalid tile. - - :return: The validity of the tile. - :rtype: Boolean - """ - return self._is_valid_tile - - @property - def tile_idx(self) -> cute.Coord: - """ - Get the index of the tile. - - :return: The index of the tile. - :rtype: cute.Coord - """ - return self._tile_idx - - -class PersistentTileSchedulerParams: - """A class to represent parameters for a persistent tile scheduler. - - This class is designed to manage and compute the layout of clusters and tiles - in a batched gemm problem. - - :ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1). - :type cluster_shape_mn: tuple - :ivar problem_layout_ncluster_mnl: Layout of the problem in terms of - number of clusters in (m, n, l) dimensions. - :type problem_layout_ncluster_mnl: cute.Layout - """ - - def __init__( - self, - problem_shape_ntile_mnl: cute.Shape, - cluster_shape_mnk: cute.Shape, - *, - loc=None, - ip=None, - ): - """ - Initializes the PersistentTileSchedulerParams with the given parameters. - - :param problem_shape_ntile_mnl: The shape of the problem in terms of - number of CTA (Cooperative Thread Array) in (m, n, l) dimensions. - :type problem_shape_ntile_mnl: cute.Shape - :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions. - :type cluster_shape_mnk: cute.Shape - - :raises ValueError: If cluster_shape_k is not 1. - """ - - if cluster_shape_mnk[2] != 1: - raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") - - self.problem_shape_ntile_mnl = problem_shape_ntile_mnl - # cluster_shape_mnk is kept for reconstruction - self._cluster_shape_mnk = cluster_shape_mnk - self.cluster_shape_mn = cluster_shape_mnk[:2] - self._loc = loc - - # By default, we follow m major (col-major) raster order, so make a col-major layout - self.problem_layout_ncluster_mnl = cute.make_layout( - cute.ceil_div( - self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]: - obj_values = extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip( - [self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos - ): - obj_list.append(new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) - - @dsl_user_op - def get_grid_shape( - self, max_active_clusters: Int32, *, loc=None, ip=None - ) -> Tuple[Integer, Integer, Integer]: - """ - Computes the grid shape based on the maximum active clusters allowed. - - :param max_active_clusters: The maximum number of active clusters that - can run in one wave. - :type max_active_clusters: Int32 - - :return: A tuple containing the grid shape in (m, n, persistent_clusters). - - m: self.cluster_shape_m. - - n: self.cluster_shape_n. - - persistent_clusters: Number of persistent clusters that can run. - """ - - # Total ctas in problem size - num_ctas_mnl = tuple( - x * y - for x, y in zip( - self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn - ) - ) + (self.problem_layout_ncluster_mnl.shape[2],) - - num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) - - num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip) - # Total ctas that can run in one wave - num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster - - num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave) - num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster - - return (*self.cluster_shape_mn, num_persistent_clusters) - - -class StaticPersistentTileScheduler: - """A scheduler for static persistent tile execution in CUTLASS/CuTe kernels. - - :ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl - :type params: PersistentTileSchedulerParams - :ivar num_persistent_clusters: Number of persistent clusters that can be launched - :type num_persistent_clusters: Int32 - :ivar cta_id_in_cluster: ID of the CTA within its cluster - :type cta_id_in_cluster: cute.Coord - :ivar _num_tiles_executed: Counter for executed tiles - :type _num_tiles_executed: Int32 - :ivar _current_work_linear_idx: Current cluster index - :type _current_work_linear_idx: Int32 - """ - - def __init__( - self, - params: PersistentTileSchedulerParams, - num_persistent_clusters: Int32, - current_work_linear_idx: Int32, - cta_id_in_cluster: cute.Coord, - num_tiles_executed: Int32, - ): - """ - Initializes the StaticPersistentTileScheduler with the given parameters. - - :param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl. - :type params: PersistentTileSchedulerParams - :param num_persistent_clusters: Number of persistent clusters that can be launched. - :type num_persistent_clusters: Int32 - :param current_work_linear_idx: Current cluster index. - :type current_work_linear_idx: Int32 - :param cta_id_in_cluster: ID of the CTA within its cluster. - :type cta_id_in_cluster: cute.Coord - :param num_tiles_executed: Counter for executed tiles. - :type num_tiles_executed: Int32 - """ - self.params = params - self.num_persistent_clusters = num_persistent_clusters - self._current_work_linear_idx = current_work_linear_idx - self.cta_id_in_cluster = cta_id_in_cluster - self._num_tiles_executed = num_tiles_executed - - def __extract_mlir_values__(self) -> list[ir.Value]: - values = extract_mlir_values(self.num_persistent_clusters) - values.extend(extract_mlir_values(self._current_work_linear_idx)) - values.extend(extract_mlir_values(self.cta_id_in_cluster)) - values.extend(extract_mlir_values(self._num_tiles_executed)) - return values - - def __new_from_mlir_values__( - self, values: list[ir.Value] - ) -> "StaticPersistentTileScheduler": - assert len(values) == 6 - new_num_persistent_clusters = new_from_mlir_values( - self.num_persistent_clusters, [values[0]] - ) - new_current_work_linear_idx = new_from_mlir_values( - self._current_work_linear_idx, [values[1]] - ) - new_cta_id_in_cluster = new_from_mlir_values( - self.cta_id_in_cluster, values[2:5] - ) - new_num_tiles_executed = new_from_mlir_values( - self._num_tiles_executed, [values[5]] - ) - return StaticPersistentTileScheduler( - self.params, - new_num_persistent_clusters, - new_current_work_linear_idx, - new_cta_id_in_cluster, - new_num_tiles_executed, - ) - - # called by host - @dsl_user_op - @staticmethod - def create( - params: PersistentTileSchedulerParams, - block_idx: Tuple[Integer, Integer, Integer], - grid_dim: Tuple[Integer, Integer, Integer], - *, - loc=None, - ip=None, - ): - """Initialize the static persistent tile scheduler. - - :param params: Parameters for the persistent - tile scheduler. - :type params: PersistentTileSchedulerParams - :param block_idx: The 3d block index in the format (bidx, bidy, bidz). - :type block_idx: Tuple[Integer, Integer, Integer] - :param grid_dim: The 3d grid dimensions for kernel launch. - :type grid_dim: Tuple[Integer, Integer, Integer] - - :return: A StaticPersistentTileScheduler object. - :rtype: StaticPersistentTileScheduler - """ - params = params - - # Calculate the number of persistent clusters by dividing the total grid size - # by the number of CTAs per cluster - num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size( - params.cluster_shape_mn, loc=loc, ip=ip - ) - - bidx, bidy, bidz = block_idx - - # Initialize workload index equals to the cluster index in the grid - current_work_linear_idx = Int32(bidz) - - # CTA id in the cluster - cta_id_in_cluster = ( - Int32(bidx % params.cluster_shape_mn[0]), - Int32(bidy % params.cluster_shape_mn[1]), - Int32(0), - ) - # Initialize number of tiles executed to zero - num_tiles_executed = Int32(0) - return StaticPersistentTileScheduler( - params, - num_persistent_clusters, - current_work_linear_idx, - cta_id_in_cluster, - num_tiles_executed, - ) - - # called by host - @staticmethod - def get_grid_shape( - params: PersistentTileSchedulerParams, - max_active_clusters: Int32, - *, - loc=None, - ip=None, - ) -> Tuple[Integer, Integer, Integer]: - """Calculates the grid shape to be launched on GPU using problem shape, - threadblock shape, and active cluster size. - - :param params: Parameters for grid shape calculation. - :type params: PersistentTileSchedulerParams - :param max_active_clusters: Maximum active clusters allowed. - :type max_active_clusters: Int32 - - :return: The calculated 3d grid shape. - :rtype: Tuple[Integer, Integer, Integer] - """ - - return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) - - # private method - def _get_current_work_for_linear_idx( - self, current_work_linear_idx: Int32, *, loc=None, ip=None - ) -> WorkTileInfo: - """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. - - :param current_work_linear_idx: The linear index of the current work. - :type current_work_linear_idx: Int32 - - :return: An object containing information about the current tile coordinates - and validity status. - :rtype: WorkTileInfo - """ - - is_valid = current_work_linear_idx < cute.size( - self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip - ) - - cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord( - current_work_linear_idx, loc=loc, ip=ip - ) - - # cur_tile_coord is a tuple of i32 values - cur_tile_coord = tuple( - Int32(x) * Int32(z) + Int32(y) - for x, y, z in zip( - cur_cluster_coord, - self.cta_id_in_cluster, - (*self.params.cluster_shape_mn, Int32(1)), - ) - ) - - return WorkTileInfo(cur_tile_coord, is_valid) - - @dsl_user_op - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: - return self._get_current_work_for_linear_idx( - self._current_work_linear_idx, loc=loc, ip=ip - ) - - @dsl_user_op - def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: - return self.get_current_work(loc=loc, ip=ip) - - @dsl_user_op - def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): - self._current_work_linear_idx += Int32(advance_count) * Int32( - self.num_persistent_clusters - ) - self._num_tiles_executed += Int32(1) - - @property - def num_tiles_executed(self) -> Int32: - return self._num_tiles_executed - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py deleted file mode 100644 index c6369c200e13ad280dfdecdb5cb4aa7ad081da4c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py +++ /dev/null @@ -1,140 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from dataclasses import dataclass -from enum import Enum, auto -from typing import Tuple - -from cutlass.cutlass_dsl import const_expr - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir - -import cutlass.cute as cute - - -class TensorMapUpdateMode(Enum): - """ - Enum class defining tensor map update modes. - - Modes: - GMEM: Update tensormap in global memory - SMEM: Load tensormap from global memory to shared memory, - update it in shared memory, then store back to global memory - """ - - GMEM = auto() # Update tensormap in global memory - SMEM = auto() # Update tensormap in shared memory - - -@dataclass(frozen=True) -class TensorMapManager: - """ - Manages TensorMap operations including initialization and updates. - Provides utilities to convert tensormap pointer to across different memory spaces. - """ - - tensormap_update_mode: TensorMapUpdateMode - bytes_per_tensormap: int - - # convert given cute.Pointer or cutlass.Int64 to a cute.Pointer to tensormap. - # address_space: the address space of the resulting tensormap pointer. It could be generic or gmem - def get_tensormap_ptr( - self, - ptr: cute.Pointer, - address_space=_cute_ir.AddressSpace.gmem, - ) -> cute.Pointer: - if address_space not in [ - _cute_ir.AddressSpace.gmem, - _cute_ir.AddressSpace.generic, - ]: - raise ValueError(f"Invalid address space: {address_space} for tensormap") - - gmem_ptr_i64 = ptr.toint().ir_value() - gmem_ptr_i64_align_ty = _cute_ir.ConstrainedIntType.get( - self.bytes_per_tensormap, gmem_ptr_i64.type.width - ) - gmem_ptr_i64_align = _cute_ir.assume(gmem_ptr_i64_align_ty, gmem_ptr_i64) - gmem_ptr_ty = _cute_ir.PtrType.get( - _cute_nvgpu_ir.TmaDescriptorTiledType.get(), - address_space, - self.bytes_per_tensormap, - ) - return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align) - - # init tensormap pointed by dst_ptr with the one inside copy_atom. - # dst_ptr should be pointing to a global memory location or a smem location - # warp_id specifies which warp to perform the initialization - @cute.jit - def init_tensormap_from_atom( - self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, warp_id: int - ) -> None: - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - if warp_idx == warp_id: - with cute.arch.elect_one(): - cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr) - cute.arch.sync_warp() - return - - # Perform a fence operation to ensure previous `init_tensormap_from_atom` calls have been completed - def fence_tensormap_initialization( - self, - ) -> None: - if self.tensormap_update_mode == TensorMapUpdateMode.GMEM: - cute.arch.fence_acq_rel_cta() - return - - # Perform a fence operation to ensure previous `update_tensormap` calls have been completed - def fence_tensormap_update( - self, - tensormap_ptr: cute.Pointer, - ) -> None: - cute.nvgpu.cpasync.fence_tma_desc_acquire(tensormap_ptr) - return - - @cute.jit - def update_tensormap( - self, - tensor_gmem: Tuple[cute.Tensor, ...], - tma_copy_atom: Tuple[cute.CopyAtom, ...], - tensormap_gmem_ptr: Tuple[cute.Pointer, ...], - warp_id: int, - tensormap_smem_ptr: Tuple[cute.Pointer, ...], - ) -> None: - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - # updates before touching tensormap in global memory - if warp_idx == warp_id: - if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): - for copy_atom, tensor, smem_ptr in zip( - tma_copy_atom, tensor_gmem, tensormap_smem_ptr - ): - cute.nvgpu.cpasync.update_tma_descriptor( - copy_atom, tensor, smem_ptr - ) - # wait until it's safe to update tensormap in global memory - with cute.arch.elect_one(): - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.sync_warp() - # updates to tensormap in global memory - if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): - for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): - cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) - else: - for copy_atom, tensor, gmem_ptr in zip( - tma_copy_atom, tensor_gmem, tensormap_gmem_ptr - ): - cute.nvgpu.cpasync.update_tma_descriptor( - copy_atom, tensor, gmem_ptr - ) - cute.arch.sync_warp() - cute.nvgpu.cpasync.fence_tma_desc_release() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py deleted file mode 100644 index 06ea3f6f5f54b0b4f125c22504b06f41e8bf7697..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from .cutlass import * - -from ..base_dsl.ast_helpers import ( - loop_selector, - if_selector, - if_executor, - while_selector, - while_executor, - range, - range_constexpr, - range_dynamic, - const_expr, - dynamic_expr, - assert_executor, - bool_cast, - compare_executor, - any_executor, - all_executor, - range_value_check, - range_perf_warning, - cf_symbol_check, - redirect_builtin_function, - copy_members, - get_locals_or_none, -) - -from ..base_dsl import * -from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values -from ..base_dsl.typing import _binary_op_type_promote -from ..base_dsl._mlir_helpers.gpu import * -from ..base_dsl._mlir_helpers.op import dsl_user_op -from ..base_dsl.runtime import * -from ..base_dsl.runtime import cuda as cuda_helpers -from ..base_dsl.compiler import compile -from ..base_dsl.runtime.jit_arg_adapters import * diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py deleted file mode 100644 index 1630c873c7a1be3e013f966ea153c904f2b776ff..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py +++ /dev/null @@ -1,1696 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -""" -This module provides a DSL for Cutlass Dialects. It also includes utils with -regarding to that dialect. -""" - -# Local module imports -from itertools import chain -from types import GenericAlias, SimpleNamespace, UnionType -from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any -import functools -import pkgutil -from dataclasses import is_dataclass, fields -from collections.abc import Sequence -import builtins - -from ..base_dsl import * -from ..base_dsl import compiler -from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values -from ..base_dsl.typing import * -from ..base_dsl.typing import DynamicExpression, get_mlir_types -from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr - -from ..base_dsl.ast_helpers import const_expr - -# MLIR Imports -from cutlass._mlir import ir, execution_engine, passmanager -from cutlass._mlir.dialects import arith, func, gpu, scf, cute, gpu as cutlass_gpu -from cutlass._mlir.dialects._ods_common import ( - get_op_result_or_op_results as _get_op_result_or_op_results, -) -from cutlass._mlir.extras import types as T - -# Helpers -from ..base_dsl._mlir_helpers import arith as cutlass_arith -from ..base_dsl._mlir_helpers import lru_cache_ir - -from ..base_dsl.ast_helpers import ( - loop_selector, - executor, - if_selector, - if_executor, - while_selector, - while_executor, - assert_executor, - const_expr, - dynamic_expr, - bool_cast, - compare_executor, - any_executor, - all_executor, - range_value_check, - range_perf_warning, - cf_symbol_check, -) - -from .cutlass_ast_decorators import ( - _loop_execute_range_dynamic, - _if_execute_dynamic, - _while_execute_dynamic, -) - -from .tree_utils import ( - is_constexpr_field, - tree_flatten, - tree_unflatten, - PyTreeDef, - is_frozen_dataclass, - DSLTreeFlattenError, -) -from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry - - -# ============================================================================= -# Cutlass DSL Base Abstract Class -# ============================================================================= - - -# Return a ctype class that represents the in-memory layout expected -# for a CuTe hierarchical tuple type. -def get_sparse_tuple_ctype(dyn): - # When there is a single dynamic value, the sparse CuTe - # representation is a single integer. - if isinstance(dyn, int): - return ctypes.c_int32 - - # For zero or greater than 1 dynamic values, the tuple - # representation will be a struct with a field for each dynamic - # value. The representation is flattened, even for hierarchical CuTe - # profiles (although we are only dealing with depth 1 inputs here). - class TupleDescriptor(ctypes.Structure): - _fields_ = [(f"x{idx}", ctypes.c_int32) for idx in range(len(dyn))] - - def __str__(self): - return f"struct<{str(self._fields_)}>" - - return TupleDescriptor - - -def is_cute_algebra_type(arg_spec): - # Walk through the arg_spec to check if it's a cute algebra type - _cute_algebra_type_aliases = ( - "Shape", - "Stride", - "Coord", - "Tile", - "IntTuple", - ) - - origin = get_origin(arg_spec) - if origin is Union: - for sub_ty in get_args(arg_spec): - sub_origin = get_origin(sub_ty) - if sub_origin is Tuple or ( - type(sub_origin) is type and issubclass(sub_origin, tuple) - ): - tuple_arg0 = get_args(sub_ty)[0] - if isinstance( - tuple_arg0, ForwardRef - ) and tuple_arg0.__forward_arg__ in (_cute_algebra_type_aliases): - return True - return False - - -def _get_c_pointers_cutlass(obj): - """ - This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict. - """ - if hasattr(obj, "__c_pointers__"): - return obj.__c_pointers__() - elif isinstance(obj, (tuple, list)): - return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj)) - elif isinstance(obj, SimpleNamespace): - return list( - chain.from_iterable( - _get_c_pointers_cutlass(x) for x in obj.__dict__.values() - ) - ) - elif isinstance(obj, dict): - return list( - chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values()) - ) - elif is_dataclass(obj): - return list( - chain.from_iterable( - _get_c_pointers_cutlass(getattr(obj, f.name)) - for f in fields(obj) - if not is_constexpr_field(f) - ) - ) - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in get_c_pointers to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - else: - # Try get adapter - adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj)) - if adapter is not None: - return _get_c_pointers_cutlass(adapter(obj)) - return [] - - -class CutlassBaseDSL(BaseDSL): - """This abstract class provides a DSL for Cutlass.""" - - def __init__( - self, - name: str, - compiler_provider: Any, - pass_sm_arch_name: str, - device_compilation_only: bool = False, - preprocess: bool = False, - ): - super().__init__( - name=name, - dsl_package_name=["cutlass"], - compiler_provider=compiler_provider, - pass_sm_arch_name=pass_sm_arch_name, - device_compilation_only=device_compilation_only, - preprocess=preprocess, - ) - self._smem_usage_tracker: tuple = None - - # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. - def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: - return False - - # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. - def _handle_tensor_descriptor( - self, maybe_tensor, arg_name: str, need_gpu_memory: bool - ) -> Any: - return False - - def _build_gpu_module(self, attrs): - self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels")) - with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])): - pass - - for attr_name in attrs: - self.gpu_module.attributes[attr_name] = ir.Attribute.parse(attrs[attr_name]) - - def _get_pipeline(self, pipeline): - pipeline = super()._get_pipeline(pipeline) - if pipeline == None: - # cubin format is required to be cubin as we launch cuda module at python level. - return ( - "builtin.module(cute-to-nvvm{cubin-format=bin " - + self.compile_options.to_str() - + "})" - ) - - return pipeline - - def preprocess_pipeline(self, pipeline, arch) -> str: - pipeline = super().preprocess_pipeline(pipeline, arch) - pipeline = pipeline.rstrip(")") + ",external-kernel-for-gpu-launch)" - return pipeline - - def _enter_gpu_module(self): - return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0]) - - def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict: - assert isinstance( - config, BaseDSL.LaunchConfig - ), f"Expect LaunchConfig for @kernel, but got {type(config)}" - - ret = {} - # generate launch bound attr from LaunchConfig - max_threads = ", ".join(map(str, config.block)) - ret["nvvm.reqntid"] = ir.Attribute.parse(f"array") - # min_blocks_per_mp is optional for kernel - min_blocks = config.min_blocks_per_mp - if min_blocks > 0: - ret["nvvm.minctasm"] = ir.Attribute.parse(f"{min_blocks} : i32") - return ret - - @lru_cache(maxsize=1) - def get_version(self): - """ - Get the version of cutlass dsl, used for computing the hash key of the cache. - Including source python files and the shared library. - """ - dsl_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - # get the version hash of the cutlass shared library - version_hash = hashlib.sha256() - # update the version hash of the source python files - for lib in pkgutil.walk_packages([dsl_path], prefix="cutlass."): - try: - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - version_hash.update(f.read()) - except Exception: - raise DSLRuntimeError( - f"Failed to read module file {lib.name}. The file may not exist or may not be readable." - "Please re-install the package." - ) - try: - # update the version hash of the cutlass shared library - with open( - os.path.join(dsl_path, "_mlir/_mlir_libs/libCutlassIRPythonCAPI.so"), - "rb", - ) as f: - while True: - chunk = f.read(1024**2) - if not chunk: - break - version_hash.update(chunk) - except Exception: - raise DSLRuntimeError( - f"Failed to read the shared library file libCutlassIRPythonCAPI.so." - "The file may not exist or may not be readable." - "Please re-install the package." - ) - - return version_hash - - @staticmethod - def track_smem_allocator(allocator, callback): - """ - Tracks shared memory usage for kernel functions. - Find and set allocator to its parent dsl object. - """ - frame = inspect.currentframe().f_back - while frame: - obj = frame.f_locals.get("self", None) - if obj and isinstance(obj, CutlassBaseDSL): - obj._set_smem_tracking(allocator, callback) - return - frame = frame.f_back - warnings.warn("Cannot find parent dsl for allocator!", UserWarning) - - def _set_smem_tracking(self, allocator, callback): - # Registers an allocator and callback for current dsl - self._smem_usage_tracker = (allocator, callback) - - def _reset_smem_tracking(self): - # Clear an allocator and callback for current dsl - self._smem_usage_tracker = None - - def _get_smem_usage(self) -> int: - # Treat final allocated bytes of allocator as smem usage - if not self._smem_usage_tracker: - return 0 - allocator, callback = self._smem_usage_tracker - return callback(allocator) - - def _kernel_helper(self, funcBody, *args, **kwargs): - class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper): - def __init__(self, dsl: CutlassBaseDSL): - super().__init__() - self.dsl = dsl - self.dsl._reset_smem_tracking() - - def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): - super().generate_func_op(arg_types, arg_attrs, kernel_name) - self.func_op = func.FuncOp( - kernel_name, ir.FunctionType.get(arg_types, []), loc=loc - ) - if arg_attrs is not None: - log().debug(arg_attrs) - self.func_op.arg_attrs = arg_attrs - return self.func_op - - def generate_func_ret_op(self): - return func.ReturnOp([]) - - def get_func_body_start(self): - assert self.func_op is not None, "Invalid func_op is not expected!" - return self.func_op.add_entry_block() - - def generate_launch_op(self, *args, **kwargs): - # Extract args and do validation - kernelSym = kwargs.get("kernelSym", None) - kernelOperands = kwargs.get("kernelOperands", None) - requiredArgs = kwargs.get("requiredArgs", None) - assert kernelSym is not None, "kernelSym being None is not expected!" - assert ( - requiredArgs is not None - ), "requiredArgs being None is not expected!" - assert ( - kernelOperands is not None - ), "kernelOperands being None is not expected!" - assert isinstance( - requiredArgs.config, BaseDSL.LaunchConfig - ), f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" - - cfg = requiredArgs.config - - # Apply to grid, block, and cluster if present - cfg.grid = [to_index(size) for size in cfg.grid] - cfg.block = [to_index(size) for size in cfg.block] - if cfg.has_cluster: - cfg.cluster = [to_index(size) for size in cfg.cluster] - - smem_usage = self.dsl._get_smem_usage() - if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]): - pass # cannot compare dynamic value inside kernel to launch op in py - elif cfg.auto_smem: - cfg.smem = smem_usage - elif smem_usage > cfg.smem: - warnings.warn( - f"Potential error: specified kernel launch smem bytes " - f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!", - UserWarning, - ) - cfg.smem = const(cfg.smem) - - if not isinstance(cfg.async_deps, (list, tuple)): - cfg.async_deps = [cfg.async_deps] - is_async = len(cfg.async_deps) > 0 - token = gpu.launch_func( - gpu.AsyncTokenType.get() if is_async else None, - cfg.async_deps, - kernelSym, - *cfg.grid, - *cfg.block, - kernelOperands, - **dict( - zip( - ("cluster_size_x", "cluster_size_y", "cluster_size_z"), - tuple(cfg.cluster), - ) - ), - dynamic_shared_memory_size=cfg.smem, - ) - return token if is_async else None - - return KernelLauncher( - self, - lambda: _CutlassIrKernelGenHelper(self), - funcBody, - *args, - **kwargs, - ) - - def _preprocess_launch_config_args(self, args, kwargs): - """Helper to preprocess args and kwargs for LaunchConfig""" - if "stream" in kwargs: - kwargs["async_deps"] = kwargs.pop("stream") - - def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): - """Mangle the name of the function to avoid conflicts with other functions""" - function_name = "cutlass_" + function_name - return super().mangle_name(function_name, args, args_spec) - - def _validate_arg(self, arg, arg_index, arg_name, arg_annotation): - """ - Validates if the arg is really of the annotated type. - """ - - if ( - is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None) - or arg_annotation is Any - ): - pass - else: - origin = get_origin(arg_annotation) - # Handle special case where annotation is Type[X] but arg is an actual type - if origin is type and isinstance(arg, type): - # Get the expected base type from Type[X] - expected_base = get_args(arg_annotation)[0] - if not issubclass(arg, expected_base): - return DSLRuntimeError( - f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}" - ) - # Handle Union types and generic types - elif origin is Union or isinstance(arg_annotation, UnionType): - # For Union types, check if arg matches any of the allowed types - allowed_types = get_args(arg_annotation) - if not any( - (ty is Any) - or (isinstance(ty, type) and isinstance(arg, ty)) - or (get_origin(ty) is tuple and isinstance(arg, tuple)) - for ty in allowed_types - ): - return DSLRuntimeError( - f"expects argument #{arg_index+1} ({arg_name}) to be one of {allowed_types}, but got {type(arg)}" - ) - elif isinstance(arg_annotation, type): - # Handle simple type annotations - if not isinstance(arg, arg_annotation) and arg is not None: - return DSLRuntimeError( - f"expects argument #{arg_index+1} ({arg_name}) to be {arg_annotation}, but got {type(arg)}" - ) - # Everything looks good if we are here - return None - - def _generate_jit_func_args_for_known_types( - self, - func, - arg, - arg_name, - arg_spec, - arg_index, - *, - is_host=True, - ): - jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] - default_attr = ir.DictAttr.get({}) - - ( - jit_exec_arg, - jit_arg_type, - jit_arg_attr, - ) = super()._generate_jit_func_args_for_known_types( - func, arg, arg_name, arg_spec, arg_index, is_host=is_host - ) - - if jit_arg_type is not None and len(jit_arg_type) == 0: - # Handle DSL specific types - if is_cute_algebra_type(arg_spec): - dyn_vals = extract_mlir_values(arg) - if dyn_vals: - # Handle dynamic types - jit_arg_type.extend([v.type for v in dyn_vals]) - jit_arg_attr.extend([default_attr] * len(dyn_vals)) - jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals) - else: - jit_exec_arg = jit_arg_type = jit_arg_attr = None - elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( - arg, "__new_from_mlir_values__" - ): - # Try tree_flatten - try: - dyn_vals, _ = tree_flatten(arg) - except DSLTreeFlattenError: - # If fails, just return the original arg - return jit_exec_arg, jit_arg_type, jit_arg_attr - - if dyn_vals: - jit_arg_type.extend([v.type for v in dyn_vals]) - jit_arg_attr.extend([default_attr] * len(dyn_vals)) - jit_exec_arg.extend( - _get_c_pointers_cutlass(arg) if is_host else dyn_vals - ) - else: - # If tree flatten yields empty list, treat it as a constexpr thing - # Like a dataclass with all fields are constexpr, or an empty tuple or list - jit_exec_arg = jit_arg_type = jit_arg_attr = None - return jit_exec_arg, jit_arg_type, jit_arg_attr - - def _generate_execution_arguments_for_known_types( - self, arg, arg_spec, arg_name, i, fop_args, iv_block_args - ): - ir_arg, iv_block_args = super()._generate_execution_arguments_for_known_types( - arg, arg_spec, arg_name, i, fop_args, iv_block_args - ) - if not ir_arg: - # Handling DSL specific types - if is_cute_algebra_type(arg_spec): - n_args = len(get_mlir_types(arg)) - blk_args = fop_args[iv_block_args : iv_block_args + n_args] - ir_arg.append(new_from_mlir_values(arg, blk_args)) - iv_block_args += n_args - elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( - arg, "__new_from_mlir_values__" - ): - # Try tree_unflatten - try: - dyn_vals, tree_def = tree_flatten(arg) - block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)] - ir_arg.append(tree_unflatten(tree_def, block_args)) - iv_block_args += len(dyn_vals) - except DSLTreeFlattenError: - return ir_arg, iv_block_args - - return ir_arg, iv_block_args - - -# ============================================================================= -# Cute DSL Class -# ============================================================================= - - -class CuTeDSL(CutlassBaseDSL): - """ - This is a concrete DSL subclass for the CuTe dialect. - """ - - def __init__(self): - name = "CUTE_DSL" - compiler_provider = compiler.Compiler(passmanager, execution_engine) - pass_sm_arch_name = "cubin-chip" - - super().__init__(name, compiler_provider, pass_sm_arch_name, preprocess=True) - - -# ============================================================================= -# KernelLauncher -# ============================================================================= - - -class KernelLauncher: - """ - This class is used to launch a kernel function. - Usage: - ```python - @cute.kernel - def kernel(arg1, arg2, ...): - ... - - @cute.jit - def launch_kernel(): - kernel(arg1, arg2, ...).launch(grid=[1, 1, 1], block=[1, 1, 1], ...) - # or - kernel(arg1, arg2, ...)(grid=[1, 1, 1], block=[1, 1, 1], ...) - ``` - """ - - def __init__( - self, - dsl: "CutlassBaseDSL", - kernelGenHelper: BaseDSL._KernelGenHelper, - funcBody, - *func_args, - **func_kwargs, - ): - self.dsl = dsl - self.kernelGenHelper = kernelGenHelper - self.funcBody = funcBody - self.func_args = func_args - self.func_kwargs = func_kwargs - - self._check_func_args(funcBody, *func_args, **func_kwargs) - - def _check_func_args(self, funcBody, *func_args, **func_kwargs): - # Get function signature - sig = inspect.signature(funcBody) - - # func_args and func_kwargs should match funcBody's signature, - # no extra or missing arguments. - try: - sig.bind(*func_args, **func_kwargs) - except TypeError as e: - raise DSLRuntimeError( - f"Failed to bind arguments to function `{funcBody.__name__}` with signature `{sig}`", - cause=e, - ) - - def smem_usage(self) -> int: - """ - Check smem usage for this kernel, only available after `launch` - """ - return self.dsl._get_smem_usage() - - def launch(self, *args, **kwargs): - self.dsl.frame = inspect.currentframe().f_back - self.dsl._preprocess_launch_config_args(args, kwargs) - config = self.dsl.LaunchConfig(*args, **kwargs) - - kernel_generator = self.dsl.kernel_launcher( - requiredArgs=["config"], - unitAttrNames=["gpu.kernel", "cute.kernel"], - valueAttrDict=self.dsl._generate_kernel_attrs(config), - kernelGenHelper=self.kernelGenHelper, - )(self.funcBody) - - ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config) - self.dsl.kernel_symbols.append(name) - self.dsl.frame = None - return ret.launch_op_ret - - def __call__(self, *args, **kwargs): - return self.launch(*args, **kwargs) - - -# ============================================================================= -# Utils -# ============================================================================= -def _filter_readonly_frozen_dataclass( - iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int -) -> List[Any]: - """ - Filter items based on whether corresponding iter_args are frozen dataclasses. - - This function filters items (which can be values or names) based on the same - logic: keep items if they correspond to full-write arguments (index < full_write_args_count) - or if the corresponding iter_arg is not a frozen dataclass. - - Args: - iter_args: List of arguments to check for frozen dataclass status - items_to_filter: List of items to filter (values or names) - full_write_args_count: Number of arguments that are always written (not read-only) - - Returns: - Filtered list of items - - Examples: - # Filter values (original remove_read_only_frozen_dataclass behavior) - filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count) - - # Filter names (original filter_readonly_frozen_dataclass_names behavior) - filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count) - """ - return [ - item - for i, item in enumerate(items_to_filter) - if i < full_write_args_count or not is_frozen_dataclass(iter_args[i]) - ] - - -def remove_read_only_frozen_dataclass( - iter_args: List[Any], full_write_args_count: int -) -> List[Any]: - """Filter out frozen dataclass arguments that are not full-write arguments.""" - return _filter_readonly_frozen_dataclass( - iter_args, iter_args, full_write_args_count - ) - - -def filter_readonly_frozen_dataclass_names( - iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int -) -> List[str]: - """Filter names based on whether corresponding iter_args are frozen dataclasses.""" - return _filter_readonly_frozen_dataclass( - iter_args, iter_args_names, full_write_args_count - ) - - -def insert_read_only_frozen_dataclass( - iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int -) -> List[Any]: - """ - Insert read-only frozen dataclass arguments back into the iteration arguments. - - This function takes the new iteration arguments and the original arguments, - and preserves frozen dataclass instances from the original arguments while - using the new arguments for non-frozen dataclass instances. - - Args: - iter_args: New iteration arguments to use for non-frozen dataclass instances - original_iter_args: Original iteration arguments to preserve frozen dataclass instances from - full_write_args_count: Number of arguments that are always written (not read-only) - - Returns: - List of arguments with frozen dataclass instances preserved from original - """ - # Take full-write arguments from new iter_args - full_write_args = ( - iter_args[:full_write_args_count] if full_write_args_count > 0 else [] - ) - - # Process remaining arguments: preserve frozen dataclass from original, use new for others - remaining_original = original_iter_args[full_write_args_count:] - remaining_new = iter_args[full_write_args_count:] - - def process_remaining_arg(original_arg, new_arg_iter): - """Process a single remaining argument, preserving frozen dataclass if present""" - return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter) - - # Use zip to pair original args with new args, then map the processing function - new_arg_iter = iter(remaining_new) - processed_remaining = [ - process_remaining_arg(orig_arg, new_arg_iter) for orig_arg in remaining_original - ] - - return full_write_args + processed_remaining - - -def unpack_to_irvalue( - mixed_values: List[Any], body_name: str, full_write_args_count: int -) -> Tuple[List[ir.Value], PyTreeDef]: - log().debug("===--- Values UNPack") - for idx, packed in enumerate(mixed_values): - log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) - - try: - unpacked_values, treedef = tree_flatten( - remove_read_only_frozen_dataclass(mixed_values, full_write_args_count) - ) - except DSLTreeFlattenError as e: - raise DSLRuntimeError( - f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.", - context={ - e.message: ( - f"All expressions within '{body_name}' must be dynamic expressions, " - "mixing Python objects and dynamic expressions is not supported. " - "The DSL failed to convert the Python object into dynamic expressions." - ) - }, - suggestion=( - f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, " - f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects." - ), - ) - - log().debug("------------------ ") - for idx, unpacked in enumerate(unpacked_values): - log().debug("[%d]: unpacked values: %s", idx, unpacked) - log().debug("treedef: %s", treedef) - log().debug("------------------ ") - - return unpacked_values, treedef - - -def pack_from_irvalue( - ir_values: List["ir.Value"], - pytree_def: PyTreeDef, - mixed_values: List[Any], - full_write_args_count: int, -) -> List[Any]: - """ - Packs MLIR values into a list of mixed values. - """ - log().debug("===--- Values Pack (%d)", len(ir_values)) - for idx, value in enumerate(ir_values): - log().debug("[%d]: will-packed: %s", idx, value) - log().debug("treedef: %s", pytree_def) - log().debug("------------------ ") - - unflattened = tree_unflatten(pytree_def, ir_values) - return insert_read_only_frozen_dataclass( - unflattened, mixed_values, full_write_args_count - ) - - -def to_index(value): - """Converts a value to an index, either by casting or coercing to int.""" - if is_dynamic_expression(value): - if isinstance(value, Numeric): - value = value.ir_value() - assert ir.IntegerType.isinstance( - value.type - ), f"expects integer type, but got {value.type}" - res = arith.index_cast(T.index(), value) - else: - res = const(int(value), ty=T.index()) - - return res - - -def _validate_iter_args_structure(iter_args, ir_values): - """ - Validates that iter_args structure contains the same number of atomic values - as there are IR values. - - Args: - iter_args: Original iteration arguments, possibly nested sequences - ir_values: Flattened MLIR values extracted from iter_args - - Returns: - bool: True if the number of atomic values in iter_args matches - the number of values in ir_values - """ - # Handle non-sequence case - if not isinstance(iter_args, (tuple, list, set)): - return not isinstance(ir_values, (tuple, list, set)) or len(ir_values) == 1 - - # If we have a sequence but ir_values isn't one, there's a mismatch - if not isinstance(ir_values, (tuple, list, set)): - return False - - # Count all non-sequence values recursively - def count_values(args): - if not isinstance(args, (tuple, list, set)): - return 1 - else: - return sum(count_values(arg) for arg in args) - - return count_values(iter_args) == len(ir_values) - - - -# ============================================================================= -# DSL implementation of Python Build-in Operators -# ============================================================================= - - -def _minmax(op, *args, loc=None, ip=None): - """Computes the minimum or maximum value from the provided arguments.""" - from ..base_dsl.typing import _binary_op, _binary_op_type_promote - - # AST Traversal doesn't support early exit in if executor - x = None - res = None - if len(args) == 1: - # Handle case for min([a, b, c, d, ..]) - if hasattr(args[0], "__iter__"): - x = op(*tuple(args[0])) - # Handle case for min(a) - else: - x = args[0] - # Handle case for min(a, b, c, ...) and min([x, y], [b]) and min(a, (x, y, z)) - elif len(args) > 1: - res, *xs = tuple(args) - for x in xs: - lhs = as_numeric(op(res, loc=loc, ip=ip)) - rhs = as_numeric(op(x, loc=loc, ip=ip)) - emitter = getattr(cutlass_arith, f"_{op.__name__}") - - lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool=True) - - if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance( - lhs, Integer - ): - lhs_val = lhs.value.with_signedness(lhs.signed) - else: - lhs_val = lhs.value - - if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance( - rhs, Integer - ): - rhs_val = rhs.value.with_signedness(rhs.signed) - else: - rhs_val = rhs.value - - res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip) - x = res - else: - raise DSLNotImplemented(f"{type(args)} is not supported") - return x - - -def min(*args, loc=None, ip=None): - """Computes the minimum value from the provided arguments. - - This function differs from Python's built-in min() in that the return type - is determined by the static types of the inputs, not their dynamic values. - - :param args: One or more values or iterables to find the minimum of - :type args: tuple - :param loc: Source location for MLIR operation tracking - :type loc: object, optional - :param ip: Insertion point for MLIR operation - :type ip: object, optional - :return: The minimum value among all inputs - :rtype: Numeric - :raises DSLNotImplemented: If the input type is not supported - - Supports multiple calling patterns: - - - min(a): Returns a - - min([a, b, c, ...]): Returns minimum of all elements in the iterable - - min(a, b, c, ...): Returns minimum of all arguments - - min([x, y], [b]): Returns minimum across all elements in all iterables - - min(a, (x, y, z)): Returns minimum across all elements - - Examples: - - .. code-block:: python - - # Find minimum of two values - result = min(x, y) - - # Find minimum of multiple values - result = min(a, b, c, d) - - # Find minimum of values in a list - values = [a, b, c, d] - result = min(values) - - # Find minimum across mixed arguments - result = min(x, [y, z]) - - Difference from Python's built-in min(): - - .. code-block:: python - - # In Python, the return type depends on the dynamic values: - a = 5 - b = 3.14 - result = min(a, b) # Returns 3.14 (float) - - # In this DSL implementation, the return type is determined statically: - a = Int32(5) - b = Float32(3.14) - result = min(a, b) # Return type is determined by the type of operands, not values - """ - return _minmax(min, *args, loc=loc, ip=ip) - - -def max(*args, loc=None, ip=None): - """Computes the maximum value from the provided arguments. - - This function differs from Python's built-in max() in that the return type - is determined by the static types of the inputs, not their dynamic values. - - :param args: One or more values or iterables to find the maximum of - :type args: tuple - :param loc: Source location for MLIR operation tracking - :type loc: object, optional - :param ip: Insertion point for MLIR operation - :type ip: object, optional - :return: The maximum value among all inputs - :rtype: Numeric - :raises DSLNotImplemented: If the input type is not supported - - Supports multiple calling patterns: - - - max(a): Returns a - - max([a, b, c, ...]): Returns maximum of all elements in the iterable - - max(a, b, c, ...): Returns maximum of all arguments - - max([x, y], [b]): Returns maximum across all elements in all iterables - - max(a, (x, y, z)): Returns maximum across all elements - - Examples: - - .. code-block:: python - - # Find maximum of two values - result = max(x, y) - - # Find maximum of multiple values - result = max(a, b, c, d) - - # Find maximum of values in a list - values = [a, b, c, d] - result = max(values) - - # Find maximum across mixed arguments - result = max(x, [y, z]) - - Difference from Python's built-in max(): - - .. code-block:: python - - # In Python, the return type depends on the dynamic values: - a = 5 - b = 3.14 - result = max(a, b) # Returns 5 (int) - - # In this DSL implementation, the return type is determined statically: - a = Int32(5) - b = Float32(3.14) - result = max(a, b) # Return type is determined by the type of operands, not values - """ - return _minmax(max, *args, loc=loc, ip=ip) - - -def and_(*args, loc=None, ip=None): - """AND operation for value in DSL numeric types. - - :param *args: One or more numeric values to AND together - :type *args: Numeric - :param loc: Source location for MLIR operation tracking - :type loc: object, optional - :param ip: Insertion point for MLIR operation - :type ip: object, optional - :return: The result of the logical AND operation - :rtype: Numeric - :raises ValueError: If no arguments are provided - - Supports multiple calling patterns: - - - and_(a): Returns a - - and_(a, b, c, ...): if a is truthy, returns and_(b, c, ...), otherwise returns a - - All arguments must be of the same type. - - Examples: - - .. code-block:: python - - # In Python, 'and' returns the second operand if the first is truthy, - # otherwise it returns the first operand - a = 5 - b = 3 - result = a and b # Returns 3 - - # In this DSL implementation, the behavior is similar but works with DSL types - a = Int32(5) - b = Int32(3) - result = and_(a, b) # Returns b - """ - if len(args) == 0: - raise ValueError("and_() requires at least one argument") - - if len(args) == 1: - return args[0] - - def and_op(lhs, rhs): - if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): - raise DSLNotImplemented(f"{type(lhs)} is not supported") - elif isinstance(lhs, (int, float, bool)) and isinstance( - rhs, (int, float, bool) - ): - return lhs and rhs - else: - return as_numeric(lhs).__dsl_and__(as_numeric(rhs)) - - return functools.reduce(and_op, args[1:], args[0]) - - -def or_(*args, loc=None, ip=None): - """Logical OR operation for DSL numeric types. - - :param *args: One or more numeric values to OR together - :type *args: Numeric - :param loc: Source location for MLIR operation tracking - :type loc: object, optional - :param ip: Insertion point for MLIR operation - :type ip: object, optional - :return: The result of the logical OR operation - :rtype: Numeric - :raises ValueError: If no arguments are provided - - Supports multiple calling patterns: - - - or_(a): Returns a - - or_(a, b, c, ...): if a is truthy, returns a, otherwise returns or_(b, c, ...) - - Examples: - - .. code-block:: python - - # In Python, 'or' returns the first operand if it's truthy, - # otherwise it returns the second operand - a = 5 - b = 3 - result = a or b # Returns 5 - - # In this DSL implementation, the behavior is similar but works with DSL types - a = Int32(5) - b = Int32(3) - result = or_(a, b) # Returns a - """ - if len(args) == 0: - raise ValueError("or_() requires at least one argument") - - if len(args) == 1: - return args[0] - - def or_op(lhs, rhs): - if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): - raise DSLNotImplemented(f"{type(lhs)} is not supported") - elif isinstance(lhs, (int, float, bool)) and isinstance( - rhs, (int, float, bool) - ): - return lhs or rhs - else: - return as_numeric(lhs).__dsl_or__(as_numeric(rhs)) - - return functools.reduce(or_op, args[1:], args[0]) - - -def all_(iterable): - """Logical AND operation for all elements in an iterable. - - Returns True if all elements in the iterable are truthy, otherwise False. - This is the DSL equivalent of Python's built-in all() function. - - :param iterable: An iterable containing values to check - :type iterable: Iterable - :return: True if all elements are truthy, False otherwise - :rtype: Boolean - - Examples: - - .. code-block:: python - - # Check if all values are non-zero - values = [Int32(1), Int32(2), Int32(3)] - result = all_(values) # Returns True - - # Check if all conditions are met - conditions = [a > 0, b < 10, c != 0] - result = all_(conditions) # Returns True if all conditions are met - """ - bool_iterable = [Boolean(i) for i in iterable] - return functools.reduce( - lambda lhs, rhs: lhs.__dsl_and__(rhs) if hasattr(lhs, "__dsl_and__") else lhs, - bool_iterable, - Boolean(True), - ) - - -def any_(iterable): - """Logical OR operation for any element in an iterable. - - Returns True if any element in the iterable is truthy, otherwise False. - This is the DSL equivalent of Python's built-in any() function. - - :param iterable: An iterable containing values to check - :type iterable: Iterable - :return: True if any element is truthy, False otherwise - :rtype: Boolean - - Examples: - - .. code-block:: python - - # Check if any value is non-zero - values = [Int32(0), Int32(0), Int32(3)] - result = any_(values) # Returns True - - # Check if any condition is met - conditions = [a > 10, b < 0, c != 0] - result = any_(conditions) # Returns True if any condition is met - """ - bool_iterable = [Boolean(i) for i in iterable] - return functools.reduce( - lambda lhs, rhs: lhs.__dsl_or__(rhs) if hasattr(lhs, "__dsl_or__") else lhs, - bool_iterable, - Boolean(False), - ) - - -# ============================================================================= -# Conditional Expression -# ============================================================================= - - -def select_(cond, if_value, else_value): - def _as_scalar(value): - if isinstance(value, list): - if len(value) == 1: - return value[0] - else: - raise DSLRuntimeError( - "Conditional expression must have exactly one value in all expressions" - ) - return value - - if not is_dynamic_expression(cond): - raise DSLRuntimeError("Conditional expression must be dynamic") - - # Extract MLIR values - cond = extract_mlir_values(cond) - if is_dynamic_expression(if_value): - if_value = extract_mlir_values(if_value) - else: - if_value = const(if_value) - if is_dynamic_expression(else_value): - else_value = extract_mlir_values(else_value) - else: - else_value = const(else_value) - - return arith.SelectOp( - _as_scalar(cond), _as_scalar(if_value), _as_scalar(else_value) - ).result - - -# ============================================================================= -# Terminator -# ============================================================================= - - -def yield_out(args=[], loc=None, ip=None): - """ - Generate a yield operation. It it used to return values from a loop, if-else, or while region. - """ - scf.yield_(extract_mlir_values(args), loc=loc, ip=ip) - - -# ============================================================================= -# For Loop -# ============================================================================= - - -class LoopUnroll(ir.Attribute): - def __init__(self, **kwargs): - valid_keys = set(["count", "full"]) - def to_mlir_attr(val): - if isinstance(val, bool): - return "true" if val else "false" - elif isinstance(val, int): - return f"{val} : i32" - else: - raise DSLNotImplemented(f"{type(val)} is not supported") - - cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs} - if kwargs.get("count", None) == 1: - cfg["disable"] = "true" - - unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">" - - super().__init__( - ir.Attribute.parse(f"#llvm.loop_annotation") - ) - - -def for_generate( - start, - stop=None, - step=None, - iter_args: Optional[Sequence[ir.Value]] = None, - *, - unroll: LoopUnroll = None, - prefetch_stages=None, - loc=None, - ip=None, -): - """ - scf.for with yield support - """ - - if step is None: - step = 1 - if stop is None: - stop = start - start = 0 - start = const(start) - params = [start, stop, step] - for i, p in enumerate(params): - if isinstance(p, int): - p = const(p) - elif isinstance(p, float): - raise DSLRuntimeError(f"{p=} must be int.") - elif isinstance(p, Integer): - p = p.ir_value() - params[i] = p - - start, stop, step = params - - def _createI32Attr(value): - if not isinstance(value, int): - raise DSLRuntimeError(f"value must be int.") - return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value) - - ir_iter_args = extract_mlir_values(iter_args) if iter_args is not None else None - if not _validate_iter_args_structure(iter_args, ir_iter_args): - raise DSLRuntimeError("iter_args: Elements should be extractable as ir.Value.") - for_op = scf.ForOp(start, stop, step, ir_iter_args, loc=loc, ip=ip) - if unroll is not None: - for_op.attributes["loop_annotation"] = unroll - - if prefetch_stages is not None: - for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages) - - iv = for_op.induction_variable - new_results = new_from_mlir_values(iter_args, for_op.results) - new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args) - new_iter_args = () if new_iter_args is None else tuple(new_iter_args) - - with ir.InsertionPoint(for_op.body): - if len(new_iter_args) > 1: - yield iv, new_iter_args, new_results - elif len(new_iter_args) == 1: - yield iv, new_iter_args[0], new_results[0] - else: - yield iv - - -# ============================================================================= -# Logical Operators -# ============================================================================= - - -def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None): - """ - Logical Not - """ - res = None - # Handle Python bool first to prevent infinite recursion - if type(lhs) == bool: - res = lhs ^ True - elif hasattr(lhs, "__dsl_not__"): - res = lhs.__dsl_not__(loc=loc, ip=ip) - elif is_dynamic_expression(lhs): - # If lhs is MLIR value, compute not using xor - res = arith.XOrIOp(lhs, const(1, lhs.type)).result - else: - res = bool(lhs) ^ True - - return res - - -# ============================================================================= -# If/Else -# ============================================================================= - - -def if_generate( - cond: Boolean, - then_body: Callable, - else_body: Optional[Callable] = None, - input_args: List[DslType] = None, - return_types: List[DslType] = None, - *, - loc=None, - ip=None, -) -> List: - """ - Generate an IfOp with optional else branch and return values. - - Args: - cond: The condition expression - then_body: Function to execute in then branch - else_body: Optional function to execute in else branch - input_args: Arguments to pass to branch bodies - return_types: Expected return types for the operation - loc: Optional location information - ip: Optional insertion point - - Returns: - List of DSL typed results - """ - input_args = input_args or [] - mlir_return_types = [] - - # Validate and collect MLIR return types (if provided). - if return_types is not None: - for t in return_types: - if not isinstance(t, DslType): - raise DSLRuntimeError(f"{t=} must be a DslType.") - mlir_return_types.append(t.mlir_type) - - # Determine whether there's an else branch. - has_else = else_body is not None - - # Create the IfOp. - if_op = scf.IfOp( - Boolean(cond).ir_value(), mlir_return_types, hasElse=has_else, loc=loc, ip=ip - ) - - def _execute_and_yield_out(body, input_args): - yield_vals = body(*input_args) - if return_types is not None: - if not isinstance(yield_vals, Iterable): - # body only return single element - yield_vals = [yield_vals] - - yield_vals = [t(r) for t, r in zip(return_types, yield_vals)] - yield_out(yield_vals) - - # Generate the body for 'then'. - with ir.InsertionPoint(if_op.then_block): - _execute_and_yield_out(then_body, input_args) - - # Generate the body for 'else' if provided. - if has_else: - with ir.InsertionPoint(if_op.else_block): - _execute_and_yield_out(else_body, input_args) - - # Collect MLIR results. - mlir_results = _get_op_result_or_op_results(if_op) - - if not isinstance(mlir_results, list): - mlir_results = [mlir_results] - - # Wrap the results with their DSL types. - if return_types is None: - return [] - - vals = [t(r) for t, r in zip(return_types, mlir_results)] - - if len(vals) == 1: - return vals[0] - - return vals - - -# ============================================================================= -# While Loop -# ============================================================================= - - -class WhileLoopContext: - """ - Context manager for a dynamic while loop. - """ - - def __init__( - self, - inputs: Sequence[Union[ir.Value, Numeric]], - condition: Callable[[Sequence[ir.Value]], ir.Value], - *, - loc=None, - ip=None, - ): - # Keep original inputs and allow recover original type information - self.inputs = inputs - - self.input_ir_values = extract_mlir_values(inputs) - - if not _validate_iter_args_structure(inputs, self.input_ir_values): - raise DSLRuntimeError("inputs: Elements should be extractable as ir.Value.") - - self.condition = condition - self.input_ir_types = [i.type for i in self.input_ir_values] - self.while_op = scf.WhileOp( - self.input_ir_types, self.input_ir_values, loc=loc, ip=ip - ) - - self.before_region = self.while_op.before - self.after_region = self.while_op.after - - self.before_region.blocks.append(*self.input_ir_types) - self.before_block = self.before_region.blocks[0] - - self.after_region.blocks.append(*self.input_ir_types) - self.after_block = self.after_region.blocks[0] - - def __enter__(self): - with ir.InsertionPoint(self.before_block): - args = new_from_mlir_values(self.inputs, self.before_block.arguments) - cond = self.condition(*args) - cond_ir_val = extract_mlir_values(cond) - scf.ConditionOp(cond_ir_val[0], [*self.before_block.arguments]) - self.ipoint_op = ir.InsertionPoint(self.after_block) - self.ipoint_op.__enter__() - return new_from_mlir_values(self.inputs, self.after_block.arguments) - - def __exit__(self, exc_type, exc_value, traceback): - self.ipoint_op.__exit__(exc_type, exc_value, traceback) - - @property - def results(self): - return new_from_mlir_values(self.inputs, self.while_op.results_) - - -def while_generate( - inputs: Sequence[Union[ir.Value, Numeric]], - condition: Callable[[Sequence[Union[ir.Value, Numeric]]], Union[ir.Value, Numeric]], - *, - loc=None, - ip=None, -) -> WhileLoopContext: - """ - Generate a WhileLoopContext for a dynamic loop. - """ - return WhileLoopContext(inputs, condition, loc=loc, ip=ip) - - -def equal(lhs, rhs): - if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): - return lhs == rhs - - # Both sequence - if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): - # Short-circuit for unequal length - if len(lhs) != len(rhs): - return False - return all_(equal(l, r) for l, r in zip(lhs, rhs)) - return lhs == rhs - - -def not_equal(lhs, rhs): - if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): - return lhs != rhs - - # Both sequence - if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): - # Short-circuit for unequal length - if len(lhs) != len(rhs): - return True - return any_(not_equal(l, r) for l, r in zip(lhs, rhs)) - - if hasattr(lhs, "__ne__"): - return lhs != rhs - elif hasattr(rhs, "__ne__"): - return rhs != lhs - else: - return not_(equal(lhs, rhs)) - - -def in_(lhs, rhs): - if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): - return lhs in rhs - - if not isinstance(rhs, Sequence): - raise DSLRuntimeError( - f"'in' not supported between instances of {type(lhs)} and {type(rhs)}" - ) - - return any_(equal(lhs, r) for r in rhs) - - -def _lte_gte(lhs, rhs, op): - def native_lte_gte(lhs, rhs, op): - match op: - case "<": - return lhs < rhs - case "<=": - if hasattr(lhs, "__le__"): - return lhs <= rhs - else: - return not_(lhs > rhs) - case ">": - return lhs > rhs - case ">=": - if hasattr(lhs, "__ge__"): - return lhs >= rhs - else: - return not_(lhs < rhs) - case _: - raise DSLRuntimeError(f"Unsupported comparison operator: {op}") - - if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): - return native_lte_gte(lhs, rhs, op) - - # Both sequence, comparisons other than == and != do not allow mixing different types of sequences - if ( - isinstance(lhs, Sequence) - and isinstance(rhs, Sequence) - and type(lhs) == type(rhs) - ): - unequal_found = False - comp_results = [] - mask = [] - for l, r in zip(lhs, rhs): - is_equal = equal(l, r) - mask.append(not_(or_(is_equal, unequal_found))) - unequal_found = not_(is_equal) - comp_results.append(_lte_gte(l, r, op)) - - result = any_(and_(r, m) for r, m in zip(comp_results, mask)) - - if len(lhs) != len(rhs): - # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types - # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one - has_valid_mask = any_(mask) - match op: - case "<": - length_result = len(lhs) < len(rhs) - case ">": - length_result = len(lhs) > len(rhs) - case "<=": - length_result = len(lhs) <= len(rhs) - case ">=": - length_result = len(lhs) >= len(rhs) - if type(has_valid_mask) == bool: - return result if has_valid_mask else length_result - else: - return select_(has_valid_mask, result, length_result) - else: - if op in {"<=", ">="}: - # If no unequal, return True - return select_(unequal_found, result, True) - else: - return result - else: - return native_lte_gte(lhs, rhs, op) - - -def greater_than(lhs, rhs): - return _lte_gte(lhs, rhs, ">") - - -def greater_equal(lhs, rhs): - return _lte_gte(lhs, rhs, ">=") - - -def less_than(lhs, rhs): - return _lte_gte(lhs, rhs, "<") - - -def less_equal(lhs, rhs): - return _lte_gte(lhs, rhs, "<=") - - -def _compare_dispatch(lhs, rhs, op): - """ - Dispatches the comparison operation between lhs and rhs based on the given operator. - - :param lhs: The left-hand side operand for the comparison. - :param rhs: The right-hand side operand for the comparison. - :param op: The comparison operator as a string. Supported operators are: - - "is", "is not": Python identity comparisons. - - "in", "not in": Membership tests. - - "==", "!=": Equality and inequality. - - "<", ">", "<=", ">=": Relational comparisons. - :return: The result of the comparison, which may be a boolean or a DSL-specific type. - :raises DSLRuntimeError: If the operator is not supported. - """ - match op: - # 'is' and 'is not' are pure python operators - case "is": - return lhs is rhs - case "is not": - return lhs is not rhs - case "in": - return in_(lhs, rhs) - case "not in": - return not_(in_(lhs, rhs)) - case "==": - return equal(lhs, rhs) - case "!=": - return not_equal(lhs, rhs) - case "<": - return less_than(lhs, rhs) - case ">": - return greater_than(lhs, rhs) - case ">=": - return greater_equal(lhs, rhs) - case "<=": - return less_equal(lhs, rhs) - case _: - raise DSLRuntimeError(f"Unsupported comparison operator: {op}") - - -def _compare_executor(left, comparators, ops): - # Fast path for single comparison - if len(comparators) == 1: - return _compare_dispatch(left, comparators[0], ops[0]) - - # Chain comparison, dispatch in a loop - result = True - current = left - for comparator, op in zip(comparators, ops): - cmp_result = _compare_dispatch(current, comparator, op) - result = and_(result, cmp_result) - current = comparator - - return result - - -def _builtin_redirector(fcn): - if fcn == builtins.max: - return max - elif fcn == builtins.min: - return min - elif fcn == builtins.any: - return any_ - elif fcn == builtins.all: - return all_ - else: - raise DSLRuntimeError(f"Unsupported built-in function: {fcn}") - - -# ============================================================================= -# Set the AST decorator -# ============================================================================= - -# Set the DSL specific functions -executor.set_functions( - is_dynamic_expression=is_dynamic_expression, - loop_execute_range_dynamic=_loop_execute_range_dynamic, - if_dynamic=_if_execute_dynamic, - while_dynamic=_while_execute_dynamic, - compare_executor=_compare_executor, - any_executor=any_, - all_executor=all_, - builtin_redirector=_builtin_redirector, -) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py deleted file mode 100644 index b5b4d8953d69b4100871a496623f051d60ab2a8d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py +++ /dev/null @@ -1,633 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import List, Tuple -from types import NoneType -from cutlass._mlir import ir -from cutlass._mlir.dialects import scf, arith -from cutlass._mlir.extras import types as T -from collections.abc import Sequence - -from ..base_dsl.dsl import is_dynamic_expression -from ..base_dsl.ast_helpers import * -from ..base_dsl.utils.logger import log -from ..base_dsl import typing as t -from ..base_dsl.typing import ( - Int32, - Float32, - Boolean, - Numeric, - get_mlir_types, - as_numeric, -) -from . import cutlass as cutlass_dsl -from .tree_utils import PyTreeDef, check_tree_equal - -# ============================================================================= -# AST Helpers -# ============================================================================= - - -class LoopUnroll(ir.Attribute): - def __init__(self, **kwargs): - valid_keys = set(["count", "full"]) - def to_mlir_attr(val): - if isinstance(val, bool): - return "true" if val else "false" - elif isinstance(val, int): - return f"{val} : i32" - else: - raise DSLNotImplemented(f"{type(val)} is not supported") - - cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs} - if kwargs.get("count", None) == 1: - cfg["disable"] = "true" - - unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">" - - super().__init__( - ir.Attribute.parse(f"#llvm.loop_annotation") - ) - - -class ScfGenerator: - """ - Encapsulates common scf dialect functionality: pack, unpack, and SCF execution. - """ - - def __init__(self): - pass - - @staticmethod - def _normalize_region_result_to_list(region_result: Any) -> List[Any]: - """ - Convert region_result to a list if it is not already a list - If region_result is a list, return it as is. - If region_result is None, return an empty list. - If region_result is not a list, return a list containing region_result as the only element. - """ - if region_result is None: - region_result_list = [] - elif not isinstance(region_result, list): - region_result_list = [region_result] - else: - region_result_list = region_result - return region_result_list - - @staticmethod - def _check_region_result(original_value, region_value, arg_name, op_type_name): - """ - Validate that a region result maintains the same type as the original value. - - This method checks for type consistency between the original value passed to a dynamic - SCF operation (like for, if, while) and the value returned from the operation's region. - - Args: - original_value: The value before entering the SCF operation region - region_value: The value returned from the SCF operation region - arg_name: Name of the argument being checked (for error reporting) - op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting - - Raises: - DSLRuntimeError: If the region value has a different type than the original value. - The error includes suggestions for using compile-time control flow instead. - - Note: - This method performs relaxed type checking that allows inheritance relationships. - For example, a child class can be returned where a parent class was expected. - However, fundamental type changes (like None to non-None, different sequence types, - or different numeric types) are not allowed in dynamic SCF operations. - """ - - def get_type_name(value): - if isinstance(value, NoneType): - return "None" - elif isinstance(value, Sequence): - return f"{type(value).__name__}<{len(value)}>" - else: - return type(value).__name__ - - # Check for type mismatches - type_mismatch = False - old_type_name = None - new_type_name = None - - # Handle None type changes - if isinstance(original_value, NoneType) != isinstance(region_value, NoneType): - type_mismatch = True - old_type_name = get_type_name(original_value) - new_type_name = get_type_name(region_value) - # Handle sequence type/length changes - elif isinstance(original_value, Sequence) and isinstance( - region_value, Sequence - ): - if type(original_value) != type(region_value) or len(original_value) != len( - region_value - ): - type_mismatch = True - old_type_name = get_type_name(original_value) - new_type_name = get_type_name(region_value) - # Handle numeric type changes - elif isinstance( - original_value, (Numeric, ArithValue, ir.Value, int, float, bool) - ) or isinstance( - region_value, (Numeric, ArithValue, ir.Value, int, float, bool) - ): - try: - original_numeric = as_numeric(original_value) - region_numeric = as_numeric(region_value) - if original_numeric.dtype != region_numeric.dtype: - type_mismatch = True - old_type_name = original_numeric.dtype.__name__ - new_type_name = region_numeric.dtype.__name__ - except Exception: - pass - # Handle general type changes (relaxed for inheritance) - elif type(original_value) != type(region_value): - old_type = type(original_value) - new_type = type(region_value) - if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)): - type_mismatch = True - old_type_name = old_type.__name__ - new_type_name = new_type.__name__ - - if type_mismatch: - raise DSLRuntimeError( - f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, " - f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.", - suggestion=( - f"Please avoid changing type inside a dynamic `{op_type_name}`, " - f"or change to compile-time control flow by marking this `{op_type_name}` with " - f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." - ), - ) - - def scf_execute_dynamic( - self, - op_type_name: str, - mix_iter_args: List[Any], - full_write_args_count: int, - mix_iter_arg_names: List[str], - create_op_func: Callable[[List[ir.Value]], ir.Operation], - region_builders: List[ - Callable[ - [ - "ir.Operation", - List["ir.Value"], # block_args - List["ir.Value"], # dyn_yield_ops - PyTreeDef, - List[Any], - int, - ], - Any, - ] - ], - # block_term_op_builder[region_builder] = scf_op_builder - # e.g. scf.ConditionOp for while loop - block_term_op_builder: Dict[Callable, Callable] = {}, - ) -> Any: - # 1) Unpack - ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue( - mix_iter_args, op_type_name, full_write_args_count - ) - # 2) Create the SCF op - op = create_op_func(ir_values) - log().debug("Generated scf.%s \n[%s]", op_type_name, op) - - # 3) Build the regions - for i, builder in enumerate(region_builders): - region = op.regions[i] - block = region.blocks[0] - with ir.InsertionPoint(block): - block_args = list(block.arguments) - region_result = builder( - op, - block_args, - ir_values, - pytree_def, - mix_iter_args, - full_write_args_count, - ) - - # Use custom terminator if provided for this builder, otherwise use default YieldOp - if builder in block_term_op_builder: - # Use the provided terminator generator - block_term_op_builder[builder](region_result, full_write_args_count) - else: - # Normalize region_result - region_result_list = ScfGenerator._normalize_region_result_to_list( - region_result - ) - # For standard yield op, check result - for arg, result, name in zip( - mix_iter_args, - region_result_list, - mix_iter_arg_names, - ): - ScfGenerator._check_region_result( - arg, result, name, op_type_name - ) - - # Default behavior - generate YieldOp - region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue( - region_result_list, op_type_name, full_write_args_count - ) - - mismatch = check_tree_equal(pytree_def, yield_pytree_def) - if mismatch != -1: - # Get arg name - filterd_arg_names = ( - cutlass_dsl.filter_readonly_frozen_dataclass_names( - mix_iter_args, mix_iter_arg_names, full_write_args_count - ) - ) - - raise DSLRuntimeError( - f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.", - suggestion=( - f"Please avoid changing type structure inside a dynamic `{op_type_name}`, " - f"or change to compile-time control flow by marking this `{op_type_name}` with " - f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." - ), - ) - - scf.YieldOp(region_values) - - log().debug("Completed scf.%s \n[%s]", op_type_name, op) - - # 4) Pack final results - final_results = cutlass_dsl.pack_from_irvalue( - op.results, pytree_def, mix_iter_args, full_write_args_count - ) - - # 5) Return in a nice pattern - if not final_results: - return - if len(final_results) == 1: - return final_results[0] - return final_results - - -def _attr_const_check(attr, expected_type, attr_name): - # Use strict type equality to prevent `bool` being accepted where `int` is required. - if is_dynamic_expression(attr) or type(attr) is not expected_type: - raise DSLRuntimeError( - f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`." - ) - - -def _loop_execute_range_dynamic( - func: Callable, - start: Any, - stop: Any, - step: Any, - mix_iter_args: List[Any] = [], - full_write_args_count: int = 0, - mix_iter_arg_names: List[str] = [], - unroll: int = -1, - unroll_full: bool = False, - prefetch_stages: int = None, -): - """ - Example: build an scf.for with optional unroll, using our universal helper. - """ - scf_gen = ScfGenerator() - - def create_for_op(dyn_yield_ops: List[ir.Value]): - for d in dyn_yield_ops: - if not isinstance(d, ir.Value): - raise DSLRuntimeError( - f"Invalid dyn_yield_ops: {dyn_yield_ops} \n\tExpected ir.Value, got {type(d)}" - ) - - # Convert Python ints or values to IR constants if needed - start_ = t.as_numeric(start) - stop_ = t.as_numeric(stop) - step_ = t.as_numeric(step) - assert start_ is not t.Int32, "Start is required for scf.for" - assert stop_ is not t.Int32, "Stop is required for scf.for" - assert step_ is not t.Int32, "Step is required for scf.for" - start_ = start_.ir_value() - stop_ = stop_.ir_value() - step_ = step_.ir_value() - - # Attributes must be pure Python value, add a check - _attr_const_check(unroll, int, "unroll") - _attr_const_check(unroll_full, bool, "unroll_full") - - # Possibly attach unroll attributes - unroll_attr = None - if unroll_full: - unroll_attr = LoopUnroll(full=True) - elif unroll != -1: - unroll_attr = LoopUnroll(count=unroll) - log().debug("Unroll attribute: %s", unroll_attr) - - prefetch_stages_attr = None - if prefetch_stages is not None: - _attr_const_check(prefetch_stages, int, "prefetch_stages") - if prefetch_stages >= 0: - prefetch_stages_attr = ir.IntegerAttr.get( - ir.IntegerType.get_signless(32), prefetch_stages - ) - else: - raise DSLRuntimeError( - f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`." - ) - log().debug("prefetch_stages attribute: %s", prefetch_stages_attr) - - log().debug( - "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s", - start_, - type(start_), - stop_, - type(stop_), - step_, - type(step_), - ) - # Create scf.ForOp, passing iteration args if any - try: - if not dyn_yield_ops: - for_op = scf.ForOp(start_, stop_, step_) - else: - for_op = scf.ForOp(start_, stop_, step_, list(dyn_yield_ops)) - except Exception as e: - yield_ops = "\n".join( - f"\t\t{i} => {d} : type : {type(d)}" - for i, d in enumerate(dyn_yield_ops) - ) - raise DSLRuntimeError( - f"Failed to create scf.ForOp \n\t\tstart={start_}: type : {type(start_)}" - f"\n\t\tstop={stop_}: type : {type(stop_)}\n\t\tstep={step_}: type : {type(step_)}" - f", \n\tdyn_yield_ops:\n{yield_ops}" - ) from e - - if unroll_attr is not None: - for_op.attributes["loop_annotation"] = unroll_attr - - if prefetch_stages_attr is not None: - for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr - - return for_op - - def for_body_builder( - op, - block_args, - _, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - # scf.ForOp block_args are typically [induction_var, iter_args...] - # But MLIR also gives you op.induction_variable - iv = t.as_numeric(op.induction_variable) - log().debug( - "For body builder: %s block_args: %s full_write_args_count: %s", - iv, - block_args, - full_write_args_count, - ) - # block_args[1:] are iteration variables - func_args = [] - func_args.extend( - cutlass_dsl.pack_from_irvalue( - block_args[1:], pytree_def, mix_iter_args, full_write_args_count - ) - ) - if not func_args: - # No iteration arguments, or only the induction var - func(iv) - return [] # yield nothing - else: - updated_func_args = func(iv, *func_args) - return updated_func_args - - # Now call the universal SCF executor with a single region builder - return scf_gen.scf_execute_dynamic( - op_type_name="for", - mix_iter_args=mix_iter_args, - full_write_args_count=full_write_args_count, - mix_iter_arg_names=mix_iter_arg_names, - create_op_func=create_for_op, - region_builders=[for_body_builder], - ) - - -def _if_execute_dynamic( - pred: "ir.Value", - then_block: Callable, - else_block: Callable = None, - mix_yield_args: List[Any] = [], - full_write_args_count: int = 0, - mix_yield_arg_names: List[str] = [], - if_constexpr=None, # ignoring for brevity -): - """ - Build an scf.if with optional else, using our universal helper. - """ - scf_gen = ScfGenerator() - - def create_if_op(dyn_yield_ops: List[ir.Value]): - # Assume final result types match the dynamic yields - result_types = [arg.type for arg in dyn_yield_ops] - - pred_ = Boolean(pred) - - try: - if_op = scf.IfOp( - pred_.ir_value(), - hasElse=(else_block is not None), - results_=result_types, - ) - except Exception as e: - raise DSLRuntimeError( - f"Failed to create scf.IfOp \n\t\tpred={pred_}: type : {type(pred_)}" - ) from e - return if_op - - def then_builder( - if_op, - _, - dyn_yield_ops, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - flat_args = [] - flat_args.extend( - cutlass_dsl.pack_from_irvalue( - dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count - ) - ) - return then_block(*flat_args) - - region_builders = [then_builder] - - if else_block is not None: - - def else_builder( - if_op, - _, - dyn_yield_ops, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - flat_args = [] - flat_args.extend( - cutlass_dsl.pack_from_irvalue( - dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count - ) - ) - return else_block(*flat_args) - - region_builders.append(else_builder) - - return scf_gen.scf_execute_dynamic( - op_type_name="if", - mix_iter_args=mix_yield_args, - full_write_args_count=full_write_args_count, - mix_iter_arg_names=mix_yield_arg_names, - create_op_func=create_if_op, - region_builders=region_builders, - ) - - -def _while_execute_dynamic( - while_before_block: Callable, - while_after_block: Callable = None, - write_args=[], - full_write_args_count=0, - write_args_names=[], -): - """ - Create and return an SCF WhileOp for dynamic loops. - Generate the dynamic loop body using SCF WhileOp. - - Args: - while_before_block: Function that returns (condition, updated_values) - while_after_block: Function that returns updated values - write_args: Values that are updated in the loop - - See create_while_function in ast_preprocessor.py for details on the input structure. - """ - log().debug("_while_execute_dynamic") - while_op_type_name = "while" - scf_gen = ScfGenerator() - - def create_while_op(dyn_yield_ops: List[ir.Value]): - # Create the while operation with the types from yield_args - result_types = [arg.type for arg in dyn_yield_ops] - try: - while_op = scf.WhileOp(result_types, dyn_yield_ops) - while_op.before.blocks.append(*result_types) - while_op.after.blocks.append(*result_types) - log().debug("[%s]", while_op) - return while_op - except Exception as e: - yield_ops = "\n".join( - f"\t\t{i} => {d} : type : {type(d)}" - for i, d in enumerate(dyn_yield_ops) - ) - raise DSLRuntimeError( - f"Failed to create scf.WhileOp with yield_ops:\n{yield_ops}" - ) from e - - def before_block_builder( - op, - block_args, - _, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - # Build the before (condition) block - flat_args = [] - flat_args.extend( - cutlass_dsl.pack_from_irvalue( - block_args, pytree_def, mix_iter_args, full_write_args_count - ) - ) - - log().debug("before block args: %s", flat_args) - - cond, before_results = while_before_block(*flat_args) - - if not isinstance(before_results, (list, ir.OpResultList)): - before_results = [before_results] - - log().debug("cond [%s]", cond) - log().debug( - "before_results [%s]", - before_results, - ) - - return cond, before_results - - def before_block_terminator(cond_and_results, full_write_args_count): - # Generate a condition op instead of yield op - cond = cond_and_results[0] - before_result_list = ScfGenerator._normalize_region_result_to_list( - cond_and_results[1] - ) - ir_cond = as_numeric(cond).ir_value() - ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue( - before_result_list, while_op_type_name, full_write_args_count - ) - log().debug( - "creating scf.ConditionOp with [%s], [%s]", - ir_cond, - ir_results_list, - ) - scf.ConditionOp(ir_cond, ir_results_list) - - def after_block_builder( - op, - block_args, - _, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - # Build the after (body) block - flat_args = [] - flat_args.extend( - cutlass_dsl.pack_from_irvalue( - block_args, pytree_def, mix_iter_args, full_write_args_count - ) - ) - - log().debug("after block args: %s", flat_args) - - after_results = while_after_block(*flat_args) - - if not isinstance(after_results, (list, ir.OpResultList)): - after_results = [after_results] - - log().debug( - "after_results [%s]", - after_results, - ) - - return after_results - - # Call the universal SCF executor with two region builders - return scf_gen.scf_execute_dynamic( - op_type_name=while_op_type_name, - mix_iter_args=write_args, - full_write_args_count=full_write_args_count, - mix_iter_arg_names=write_args_names, - create_op_func=create_while_op, - region_builders=[before_block_builder, after_block_builder], - block_term_op_builder={ - before_block_builder: before_block_terminator - }, # Only customize the before block - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/tree_utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/tree_utils.py deleted file mode 100644 index 599b72ea5c6b1d378480ceeb1d43d14fd58b569d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/tree_utils.py +++ /dev/null @@ -1,763 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin -import dataclasses -import itertools as it -from types import SimpleNamespace - -from ..base_dsl.typing import as_numeric, Numeric, Constexpr -from ..base_dsl._mlir_helpers.arith import ArithValue -from ..base_dsl.common import DSLBaseError -from .._mlir import ir - -# ============================================================================= -# Tree Utils -# ============================================================================= - - -class DSLTreeFlattenError(DSLBaseError): - """Exception raised when tree flattening fails due to unsupported types.""" - - def __init__(self, msg: str, type_str: str): - super().__init__(msg) - self.type_str = type_str - - -def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]: - """Unzip a sequence of pairs into two lists.""" - lst1, lst2 = [], [] - for x1, x2 in pairs: - lst1.append(x1) - lst2.append(x2) - return lst1, lst2 - - -def get_fully_qualified_class_name(x: Any) -> str: - """ - Get the fully qualified class name of an object. - - Args: - x: Any object - - Returns: - str: Fully qualified class name in format 'module.class_name' - - Example: - >>> get_fully_qualified_class_name([1, 2, 3]) - 'builtins.list' - """ - return f"{x.__class__.__module__}.{x.__class__.__qualname__}" - - -def is_frozen_dataclass(obj_or_cls: Any) -> bool: - """ - Check if an object or class is a frozen dataclass. - - Args: - obj_or_cls: Either a dataclass instance or class - - Returns: - bool: True if the object/class is a dataclass declared with frozen=True, - False otherwise - - Example: - >>> from dataclasses import dataclass - >>> @dataclass(frozen=True) - ... class Point: - ... x: int - ... y: int - >>> is_frozen_dataclass(Point) - True - >>> is_frozen_dataclass(Point(1, 2)) - True - """ - cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__ - - return ( - dataclasses.is_dataclass(cls) - and getattr(cls, "__dataclass_params__", None) is not None - and cls.__dataclass_params__.frozen - ) - - -def is_dynamic_expression(x: Any) -> bool: - """ - Check if an object implements the DynamicExpression protocol. - - Objects implementing this protocol must have both `__extract_mlir_values__` - and `__new_from_mlir_values__` methods. - - Args: - x: Any object to check - - Returns: - bool: True if the object implements the DynamicExpression protocol, - False otherwise - """ - return all( - hasattr(x, attr) - for attr in ("__extract_mlir_values__", "__new_from_mlir_values__") - ) - - -def is_constexpr_field(field: dataclasses.Field) -> bool: - """ - Check if a field is a constexpr field. - """ - if field.type is Constexpr: - return True - elif get_origin(field.type) is Constexpr: - return True - return False - - -# ============================================================================= -# PyTreeDef -# ============================================================================= - -class NodeType(NamedTuple): - """ - Represents a node in a pytree structure. - - Attributes: - name: String representation of the node type - to_iterable: Function to convert node to iterable form - from_iterable: Function to reconstruct node from iterable form - """ - name: str - to_iterable: Callable - from_iterable: Callable - - -class PyTreeDef(NamedTuple): - """ - Represents the structure definition of a pytree. - - Attributes: - node_type: The type of this node - node_metadata: SimpleNamespace metadata associated with this node - child_treedefs: Tuple of child tree definitions - """ - node_type: NodeType - node_metadata: SimpleNamespace - child_treedefs: tuple["PyTreeDef", ...] - - -@dataclasses.dataclass(frozen=True) -class Leaf: - """ - Represents a leaf node in a pytree structure. - - Attributes: - is_numeric: Whether this leaf contains a `Numeric` value - is_none: Whether this leaf represents None - node_metadata: SimpleNamespace metadata associated with this leaf - ir_type_str: String representation of the IR type - """ - is_numeric: bool = False - is_none: bool = False - node_metadata: SimpleNamespace = None - ir_type_str: str = None - - -# ============================================================================= -# Default to_iterable and from_iterable -# ============================================================================= - - -def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: - """ - Extract non-method, non-function attributes from a dataclass instance. - - Args: - x: A dataclass instance - - Returns: - tuple: (field_names, field_values) lists - """ - fields = [field.name for field in dataclasses.fields(x)] - - # If the dataclass has extra fields, raise an error - for k in x.__dict__.keys(): - if k not in fields: - raise DSLTreeFlattenError( - f"`{x}` has extra field `{k}`", - type_str=get_fully_qualified_class_name(x), - ) - - if not fields: - return [], [] - - # record constexpr fields - members = [] - constexpr_fields = [] - for field in dataclasses.fields(x): - if is_constexpr_field(field): - constexpr_fields.append(field.name) - fields.remove(field.name) - v = getattr(x, field.name) - if is_dynamic_expression(v): - raise DSLTreeFlattenError( - f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`", - type_str=get_fully_qualified_class_name(x), - ) - else: - members.append(getattr(x, field.name)) - - return fields, members, constexpr_fields - - -def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: - """ - Convert a dataclass instance to iterable form for tree flattening. - - Extracts all non-method, non-function attributes that don't start with '__' - and returns them along with metadata about the dataclass. - - Args: - x: A dataclass instance - - Returns: - tuple: (metadata, members) where metadata contains type info and field names, - and members is the list of attribute values - """ - fields, members, constexpr_fields = extract_dataclass_members(x) - - metadata = SimpleNamespace( - type_str=get_fully_qualified_class_name(x), - fields=fields, - constexpr_fields=constexpr_fields, - original_obj=x, - ) - return metadata, members - - -def set_dataclass_attributes( - instance: Any, - fields: list[str], - values: Iterable[Any], - constexpr_fields: list[str], -) -> Any: - """ - Set attributes on a dataclass instance. - - Args: - instance: The dataclass instance - fields: List of field names - values: Iterable of field values - is_frozen: Whether the dataclass is frozen - - Returns: - The instance with attributes set - """ - if not fields: - return instance - - kwargs = dict(zip(fields, values)) - for field in constexpr_fields: - kwargs[field] = getattr(instance, field) - return dataclasses.replace(instance, **kwargs) - -def default_dataclass_from_iterable( - metadata: SimpleNamespace, children: Iterable[Any] -) -> Any: - """ - Reconstruct a dataclass instance from iterable form. - - Handles both regular and frozen dataclasses appropriately. - - Args: - metadata: Metadata containing type information and field names - children: Iterable of attribute values to reconstruct the instance - - Returns: - The reconstructed dataclass instance - """ - instance = metadata.original_obj - - new_instance = set_dataclass_attributes( - instance, metadata.fields, children, metadata.constexpr_fields - ) - metadata.original_obj = new_instance - return new_instance - - -def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: - """ - Convert a dynamic expression to iterable form. - - Uses the object's `__extract_mlir_values__` method to extract MLIR values. - - Args: - x: A dynamic expression object - - Returns: - tuple: (metadata, mlir_values) where metadata marks this as a dynamic expression - and mlir_values are the extracted MLIR values - """ - return ( - SimpleNamespace(is_dynamic_expression=1, original_obj=x), - x.__extract_mlir_values__(), - ) - - -def dynamic_expression_from_iterable( - metadata: SimpleNamespace, children: Iterable[Any] -) -> Any: - """ - Reconstruct a dynamic expression from iterable form. - - Uses the object's `__new_from_mlir_values__` method to reconstruct from MLIR values. - - Args: - metadata: Metadata containing the original object - children: Iterable of MLIR values to reconstruct from - - Returns: - The reconstructed dynamic expression object - """ - return metadata.original_obj.__new_from_mlir_values__(list(children)) - - -def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: - """ - Convert a dict to iterable form. - """ - if isinstance(x, SimpleNamespace): - keys = list(x.__dict__.keys()) - values = list(x.__dict__.values()) - else: - keys = list(x.keys()) - values = list(x.values()) - - return ( - SimpleNamespace( - type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys - ), - values, - ) - - -def default_dict_from_iterable( - metadata: SimpleNamespace, children: Iterable[Any] -) -> Any: - """ - Reconstruct a dict from iterable form. - """ - instance = metadata.original_obj - fields = metadata.fields - is_simple_namespace = isinstance(instance, SimpleNamespace) - - for k, v in zip(fields, children): - if is_simple_namespace: - setattr(instance, k, v) - else: - instance[k] = v - - return instance - - -# ============================================================================= -# Register pytree nodes -# ============================================================================= - -_node_types: dict[type, NodeType] = {} - - -def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType: - """ - Register a new node type for pytree operations. - - Args: - ty: The type to register - to_iter: Function to convert instances of this type to iterable form - from_iter: Function to reconstruct instances of this type from iterable form - - Returns: - NodeType: The created NodeType instance - """ - nt = NodeType(str(ty), to_iter, from_iter) - _node_types[ty] = nt - return nt - - -def register_default_node_types() -> None: - """Register default node types for pytree operations.""" - default_registrations = [ - ( - tuple, - lambda t: (SimpleNamespace(length=len(t)), list(t)), - lambda _, xs: tuple(xs), - ), - ( - list, - lambda l: (SimpleNamespace(length=len(l)), list(l)), - lambda _, xs: list(xs), - ), - ( - dict, - default_dict_to_iterable, - default_dict_from_iterable, - ), - ( - SimpleNamespace, - default_dict_to_iterable, - default_dict_from_iterable, - ), - ] - - for ty, to_iter, from_iter in default_registrations: - register_pytree_node(ty, to_iter, from_iter) - - -# Initialize default registrations -register_default_node_types() - - -# ============================================================================= -# tree_flatten and tree_unflatten -# ============================================================================= - -""" -Behavior of tree_flatten and tree_unflatten, for example: - -```python - a = (1, 2, 3) - b = MyClass(a=1, b =[1,2,3]) -``` - -yields the following tree: - -```python - tree_a = PyTreeDef(type = 'tuple', - metadata = {length = 3}, - children = [ - Leaf(type = int), - Leaf(type = int), - Leaf(type = int), - ], - ) - flattened_a = [1, 2, 3] - tree_b = PyTreeDef(type = 'MyClass', - metadata = {fields = ['a','b']}, - children = [ - PyTreeDef(type = `list`, - metadata = {length = 3}, - children = [ - Leaf(type=`int`), - Leaf(type=`int`), - Leaf(type=`int`), - ], - ), - Leaf(type=int), - ], - ) - flattened_b = [1, 1, 2, 3] -``` - -Passing the flattened values and PyTreeDef to tree_unflatten to reconstruct the original structure. - -``` python - unflattened_a = tree_unflatten(tree_a, flattened_a) - unflattened_b = tree_unflatten(tree_b, flattened_b) -``` - -yields the following structure: - -``` python - unflattened_a = (1, 2, 3) - unflattened_b = MyClass(a=1, b =[1,2,3]) -``` - -unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b. - -""" - - -def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]: - """ - Flatten a nested structure into a flat list of values and a tree definition. - - This function recursively traverses nested data structures (trees) and - flattens them into a linear list of leaf values, while preserving the - structure information in a PyTreeDef. - - Args: - x: The nested structure to flatten - - Returns: - tuple: (flat_values, treedef) where flat_values is a list of leaf values - and treedef is the tree structure definition - - Raises: - DSLTreeFlattenError: If the structure contains unsupported types - - Example: - >>> tree_flatten([1, [2, 3], 4]) - ([1, 2, 3, 4], PyTreeDef(...)) - """ - children_iter, treedef = _tree_flatten(x) - return list(children_iter), treedef - - -def get_registered_node_types_or_insert(x: Any) -> NodeType | None: - """ - Get the registered node type for an object, registering it if necessary. - - This function checks if a type is already registered for pytree operations. - If not, it automatically registers the type based on its characteristics: - - Dynamic expressions get registered with dynamic expression handlers - - Dataclasses get registered with default dataclass handlers - - Args: - x: The object to get or register a node type for - - Returns: - NodeType or None: The registered node type, or None if the type - cannot be registered - """ - node_type = _node_types.get(type(x)) - if node_type: - return node_type - elif is_dynamic_expression(x): - # If a class implements DynamicExpression protocol, register it before default dataclass one - return register_pytree_node( - type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable - ) - elif dataclasses.is_dataclass(x): - return register_pytree_node( - type(x), default_dataclass_to_iterable, default_dataclass_from_iterable - ) - else: - return None - - -def create_leaf_for_value( - x: Any, - is_numeric: bool = False, - is_none: bool = False, - node_metadata: SimpleNamespace = None, - ir_type_str: str = None, -) -> Leaf: - """ - Create a Leaf node for a given value. - - Args: - x: The value to create a leaf for - is_numeric: Whether this is a numeric value - is_none: Whether this represents None - node_metadata: Optional metadata - ir_type_str: Optional IR type string - - Returns: - Leaf: The created leaf node - """ - return Leaf( - is_numeric=is_numeric, - is_none=is_none, - node_metadata=node_metadata, - ir_type_str=ir_type_str or (str(x.type) if hasattr(x, "type") else None), - ) - - -def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]: - """ - Internal function to flatten a tree structure. - - This is the core implementation of tree flattening that handles different - types of objects including None, ArithValue, ir.Value, Numeric types, - and registered pytree node types. - - Args: - x: The object to flatten - - Returns: - tuple: (flattened_values, treedef) where flattened_values is an iterable - of leaf values and treedef is the tree structure - - Raises: - DSLTreeFlattenError: If the object type is not supported - """ - match x: - case None: - return [], create_leaf_for_value(x, is_none=True) - - case ArithValue() if is_dynamic_expression(x): - v = x.__extract_mlir_values__() - return v, create_leaf_for_value( - x, - node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), - ir_type_str=str(v[0].type), - ) - - case ArithValue(): - return [x], create_leaf_for_value(x, is_numeric=True) - - case ir.Value(): - return [x], create_leaf_for_value(x) - - case Numeric(): - v = x.__extract_mlir_values__() - return v, create_leaf_for_value( - x, - node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), - ir_type_str=str(v[0].type), - ) - - case _: - node_type = get_registered_node_types_or_insert(x) - if node_type: - node_metadata, children = node_type.to_iterable(x) - children_flat, child_trees = unzip2(map(_tree_flatten, children)) - flattened = it.chain.from_iterable(children_flat) - return flattened, PyTreeDef( - node_type, node_metadata, tuple(child_trees) - ) - - # Try to convert to numeric - try: - nval = as_numeric(x).ir_value() - return [nval], create_leaf_for_value(nval, is_numeric=True) - except Exception: - raise DSLTreeFlattenError( - "Flatten Error", get_fully_qualified_class_name(x) - ) - - -def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: - """ - Reconstruct a nested structure from a flat list of values and tree definition. - - This is the inverse operation of tree_flatten. It takes the flattened - values and the tree structure definition to reconstruct the original - nested structure. - - Args: - treedef: The tree structure definition from tree_flatten - xs: List of flat values to reconstruct from - - Returns: - The reconstructed nested structure - - Example: - >>> flat_values, treedef = tree_flatten([1, [2, 3], 4]) - >>> tree_unflatten(treedef, flat_values) - [1, [2, 3], 4] - """ - return _tree_unflatten(treedef, iter(xs)) - - -def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: - """ - Internal function to reconstruct a tree structure. - - This is the core implementation of tree unflattening that handles - different types of tree definitions including Leaf nodes and PyTreeDef nodes. - - Args: - treedef: The tree structure definition - xs: Iterator of flat values to reconstruct from - - Returns: - The reconstructed object - """ - match treedef: - case Leaf(is_none=True): - return None - - case Leaf( - node_metadata=metadata - ) if metadata and metadata.is_dynamic_expression: - return metadata.original_obj.__new_from_mlir_values__([next(xs)]) - - case Leaf(is_numeric=True): - return as_numeric(next(xs)) - - case Leaf(): - return next(xs) - - case PyTreeDef(): - children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) - return treedef.node_type.from_iterable(treedef.node_metadata, children) - - -def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool: - """ - Check if two tree definitions are structurally equal. - - This is a helper function for check_tree_equal that recursively compares - tree structures. - - Args: - lhs: Left tree definition (PyTreeDef or Leaf) - rhs: Right tree definition (PyTreeDef or Leaf) - - Returns: - bool: True if the trees are structurally equal, False otherwise - """ - match (lhs, rhs): - case (Leaf(), Leaf()): - return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str - - case (PyTreeDef(), PyTreeDef()): - lhs_metadata = lhs.node_metadata - rhs_metadata = rhs.node_metadata - - lhs_fields = getattr(lhs_metadata, "fields", []) - rhs_fields = getattr(rhs_metadata, "fields", []) - lhs_constexpr_fields = getattr(lhs_metadata, "constexpr_fields", []) - rhs_constexpr_fields = getattr(rhs_metadata, "constexpr_fields", []) - - return ( - lhs.node_type == rhs.node_type - and lhs_fields == rhs_fields - and lhs_constexpr_fields == rhs_constexpr_fields - and len(lhs.child_treedefs) == len(rhs.child_treedefs) - and all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs)) - ) - - case _: - return False - - -def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int: - """ - Check if two tree definitions are equal and return the index of first difference. - - This function compares two tree definitions and returns the index of the - first child that differs, or -1 if they are completely equal. - - Args: - lhs: Left tree definition - rhs: Right tree definition - - Returns: - int: Index of the first differing child, or -1 if trees are equal - - Example: - >>> treedef1 = tree_flatten([1, [2, 3]])[1] - >>> treedef2 = tree_flatten([1, [2, 4]])[1] - >>> check_tree_equal(treedef1, treedef2) - 1 # The second child differs - """ - assert len(lhs.child_treedefs) == len(rhs.child_treedefs) - - def find_first_difference( - index_and_pair: tuple[int, tuple[PyTreeDef, PyTreeDef]] - ) -> int: - index, (l, r) = index_and_pair - return index if not _check_tree_equal(l, r) else -1 - - differences = map( - find_first_difference, enumerate(zip(lhs.child_treedefs, rhs.child_treedefs)) - ) - return next((diff for diff in differences if diff != -1), -1) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/__init__.py deleted file mode 100644 index 9bdd259c0203aaca3c7a7e31e64a576630f369a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/__init__.py +++ /dev/null @@ -1,213 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -import logging -import os -import sys - -import cutlass_library - - -def _cuda_install_path_from_nvcc() -> str: - import subprocess - # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC - result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True) - if result.returncode != 0: - raise Exception(f'Unable to find nvcc via `which` utility.') - - cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0] - if not os.path.isdir(cuda_install_path): - raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, ' - f'and default path of {cuda_install_path} does not exist.') - - return cuda_install_path - - -CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path) - -# Alias CUTLASS_PATH as source_path -source_path = CUTLASS_PATH - -_NVCC_VERSION = None -def nvcc_version(): - global _NVCC_VERSION - if _NVCC_VERSION is None: - import subprocess - - # Attempt to get NVCC version - result = subprocess.run(['nvcc', '--version'], capture_output=True) - if result.returncode != 0: - raise Exception('Unable to run `nvcc --version') - _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0] - return _NVCC_VERSION - -_CUDA_INSTALL_PATH = None -def cuda_install_path(): - """ - Helper method for on-demand fetching of the CUDA installation path. This allows - the import of CUTLASS to proceed even if NVCC is not available, preferring to - raise this error only when an operation that needs NVCC is being performed. - """ - global _CUDA_INSTALL_PATH - if _CUDA_INSTALL_PATH is None: - _CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc()) - return _CUDA_INSTALL_PATH - -CACHE_FILE = "compiled_cache.db" - -from cutlass_library import ( - DataType, - EpilogueScheduleType, - KernelScheduleType, - MathOperation, - LayoutType, - OpcodeClass, - TileDescription, - TileSchedulerType, -) - -this = sys.modules[__name__] -this.logger = logging.getLogger(__name__) - -# RMM is only supported for Python 3.9+ -if (sys.version_info.major == 3 and sys.version_info.minor > 8) or sys.version_info.major > 3: - try: - import rmm - this.use_rmm = True - except ImportError: - this.use_rmm = False -else: - this.use_rmm = False - - -def set_log_level(level: int): - """ - Sets the log level - - :param log_level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options - :type log_level: int - """ - this.logger.setLevel(level) - -set_log_level(logging.ERROR) - -from cutlass_cppgen.library_defaults import OptionRegistry -from cutlass_cppgen.backend.utils.device import device_cc - -this._option_registry = None -def get_option_registry(): - """ - Helper method for on-demand initialization of the options registry. This avoids building - the registry when CUTLASS is imported. - """ - if this._option_registry is None: - this.logger.info("Initializing option registry") - this._option_registry = OptionRegistry(device_cc()) - return this._option_registry - -this.__version__ = '4.2.1' - -from cutlass_cppgen.backend import create_memory_pool -from cutlass_cppgen.emit.pytorch import pytorch -from cutlass_cppgen.op.gemm import Gemm -from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad -from cutlass_cppgen.op.gemm_grouped import GroupedGemm -from cutlass_cppgen.op.op import OperationBase -from cutlass_cppgen.backend.evt.ir.tensor import Tensor -from cutlass_cppgen.utils.lazy_import import lazy_import - - -this.memory_pool = None -def get_memory_pool(): - """" - Helper method for on-demand memory pool. This avoids allocating the memory pool unnecessarily - whe CUTLASS is imported. - """ - if this.use_rmm and this.memory_pool is None: - this.memory_pool = create_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32) - return this.memory_pool - - -base_cuda = lazy_import("cuda") -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") - -this._device_id = None -this._nvcc_version = None - -def check_cuda_versions(): - # Strip any additional information from the CUDA version - _cuda_version = base_cuda.__version__.split("rc")[0] - # Check that Python CUDA version exceeds NVCC version - this._nvcc_version = nvcc_version() - _cuda_list = _cuda_version.split('.') - _nvcc_list = this._nvcc_version.split('.') - for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list): - if int(val_cuda) < int(val_nvcc): - raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}") - - if len(_nvcc_list) > len(_cuda_list): - if len(_nvcc_list) != len(_cuda_list) + 1: - raise Exception(f"Malformatted NVCC version of {this._nvcc_version}") - if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0: - raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}") - -def initialize_cuda_context(): - check_cuda_versions() - - if this._device_id is not None: - return - - if this.use_rmm: - # This also covers initializing the CUDA context - get_memory_pool() - - device_id = os.getenv("CUTLASS_CUDA_DEVICE_ID") - if device_id is None: - if not this.use_rmm: - # Manually call cuInit() and create context by making a runtime API call - err, = cudart.cudaFree(0) - if err != cudart.cudaError_t.cudaSuccess: - raise RuntimeError(f"cudaFree failed with error {err}") - - err, device_count = cuda.cuDeviceGetCount() - if err != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"cuDeviceGetCount failed with error {err}") - if device_count <= 0: - raise Exception("No CUDA devices found") - device_id = 0 - - this._device_id = int(device_id) - - -def device_id() -> int: - initialize_cuda_context() - return this._device_id diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/__init__.py deleted file mode 100644 index 59cfaf7154687fa3a971f2221f0cce2130ff1a4f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.backend.arguments import * -from cutlass_cppgen.backend.c_types import * -from cutlass_cppgen.backend.compiler import ArtifactManager -from cutlass_cppgen.backend.conv2d_operation import * -from cutlass_cppgen.backend.epilogue import * -from cutlass_cppgen.backend.frontend import * -from cutlass_cppgen.backend.gemm_operation import * -from cutlass_cppgen.backend.library import * -from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool -from cutlass_cppgen.backend.operation import * -from cutlass_cppgen.backend.reduction_operation import * -from cutlass_cppgen.backend.type_hint import * -from cutlass_cppgen.backend.utils import * -from cutlass_cppgen.backend.utils.device import device_cc - -compiler = ArtifactManager() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py deleted file mode 100644 index b1b0656a89a8b0a42b864429810b74bc433582d4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py +++ /dev/null @@ -1,136 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from math import prod -from typing import Union - -from cutlass_cppgen.utils.lazy_import import lazy_import - -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") -import numpy as np - -import cutlass_cppgen -from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend -from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper -from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor - - -class ArgumentBase: - """ - Base class for operation arguments - """ - - def __init__( - self, - A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", - B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", - C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", - D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", - **kwargs, - ) -> None: - # tensor_C can be interpreted as the bias with bias=True in keyword args - self.bias = kwargs.get("bias", False) - - self.stream = kwargs.get("stream", cuda.CUstream(0)) - - # RMM buffers used to track tensor lifetime - self.buffers = {} - # Host tensor to copy the computed result back - self.host_tensors = {} - - self.ptr_A = self.tensor_to_ptr(A, "A") - self.ptr_B = self.tensor_to_ptr(B, "B") - self.ptr_C = self.tensor_to_ptr(C, "C") - self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True) - if C is not None: - if not isinstance(C, cuda.CUdeviceptr): - self.tensor_c_numel = prod(C.shape) - - def tensor_to_ptr(self, tensor, name, is_output=False): - """ - Convert and remember the input tensor to cuda.CUdeviceptr used by cuda python - For numpy.ndarray, it also remembers the host buffer for synchronization - """ - if tensor is None: - return cuda.CUdeviceptr(0) - if is_numpy_tensor(tensor): - if is_output: - assert name - self.buffers[name] = NumpyFrontend.argument(tensor, is_output) - if is_output: - self.host_tensors[name] = tensor - return self.buffers[name].ptr - elif is_torch_tensor(tensor): - return TorchFrontend.argument(tensor) - elif isinstance(tensor, cuda.CUdeviceptr): - return tensor - elif is_cupy_tensor(tensor): - return CupyFrontend.argument(tensor) - else: - raise TypeError("Unsupported Frontend. Only support numpy and torch") - - def sync(self, stream_sync=True): - if stream_sync: - (err,) = cudart.cudaDeviceSynchronize() - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - for key in self.host_tensors.keys(): - host_tensor = self.host_tensors[key] - (err,) = cuda.cuMemcpyDtoH( - host_tensor, - self.buffers[key].ptr, - host_tensor.size * host_tensor.itemsize, - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - self.free() - - def free(self): - """ - Frees allocated device-side memory - """ - # Free any device memory allocated manually - if not cutlass_cppgen.use_rmm: - for name, buf in self.buffers.items(): - if isinstance(buf, DevicePtrWrapper): - err, = cudart.cudaFree(buf.ptr) - if err != cudart.cudaError_t.cudaSuccess: - raise RuntimeError(f"cudaFree failed with error {err}") - - if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper): - err, = cudart.cudaFree(self.workspace_buffer.ptr) - if err != cudart.cudaError_t.cudaSuccess: - raise RuntimeError(f"cudaFree failed with error {err}") - del self.workspace_buffer diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/c_types.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/c_types.py deleted file mode 100644 index 3f515aa38439e4b2e1392659d188cbe6a68e0481..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/c_types.py +++ /dev/null @@ -1,625 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import ctypes - -from cutlass_library import ( - DataType, - KernelScheduleType, - TileSchedulerType -) -from cutlass_cppgen.backend.library import DataTypeSizeBytes - - -class GemmCoord_(ctypes.Structure): - _fields_ = [ - ("m", ctypes.c_int), - ("n", ctypes.c_int), - ("k", ctypes.c_int) - ] - - def __init__(self, m, n, k) -> None: - self.m = m - self.n = n - self.k = k - - -class GemmCoordBatched_(ctypes.Structure): - """ - Wrapper around a GemmCoord that also contains batch count. This is used for encoding - batched GEMM inputs to CUTLASS 3 GEMMs. - """ - - _fields_ = [ - ("m", ctypes.c_int), - ("n", ctypes.c_int), - ("k", ctypes.c_int), - ("batch_count", ctypes.c_int) - ] - - def __init__(self, gemm_coord, batch_count) -> None: - self.m = gemm_coord.m - self.n = gemm_coord.n - self.k = gemm_coord.k - self.batch_count = batch_count - - -class MatrixCoord_(ctypes.Structure): - _fields_ = [ - ("row", ctypes.c_int), - ("column", ctypes.c_int) - ] - - -class dim3_(ctypes.Structure): - _fields_ = [ - ("x", ctypes.c_int), - ("y", ctypes.c_int), - ("z", ctypes.c_int) - ] - - -class StrideBatched_(ctypes.Structure): - """ - CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The - variable dimensions represent the stride along non-unit-stride dimension of the row/column major - layout, and the batch stride. This structure encodes the two variable dimensions. - """ - _fields_ = [ - ("major_stride", ctypes.c_int64), - ("batch_stride", ctypes.c_int64) - ] - - - -class GenericMainloopArguments3x_(ctypes.Structure): - """ - Structure representing the superset of possible mainloop arguments. - This structure should not be passed to kernels directly, but, rather, - be used as an input to one of the more specific schedule arguments, which - will each select those arguments relevant to the particular schedule. - """ - _fields_ = [ - ("ptr_A", ctypes.c_void_p), - ("stride_A", StrideBatched_), - ("ptr_B", ctypes.c_void_p), - ("stride_B", StrideBatched_), - ("mma_promotion_interval", ctypes.c_int) - ] - - -class _PersistentTileSchedulerArguments(ctypes.Structure): - _fields_ = [ - ("max_swizzle_size", ctypes.c_int), - ("raster_order_option", ctypes.c_int), - ] - - -class _PersistentTileSchedulerStreamKArguments(ctypes.Structure): - _fields_ = [ - ("splits", ctypes.c_int), - ("max_swizzle_size", ctypes.c_int), - ("raster_order_option", ctypes.c_int), - ("reduction_mode", ctypes.c_int), - ("decomposition_mode", ctypes.c_int), - ] - - -def get_tile_scheduler_arguments_3x( - tile_scheduler: TileSchedulerType, - splits: int = 1): - max_swizzle_size = 1 - raster_order_option = 0 # Heuristic - if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]: - return _PersistentTileSchedulerArguments( - max_swizzle_size, - raster_order_option, - ) - elif tile_scheduler == TileSchedulerType.StreamK: - reduction_mode = 0 # Deterministic - decomposition_mode = 0 # Heuristic - return _PersistentTileSchedulerStreamKArguments( - splits, - max_swizzle_size, - raster_order_option, - reduction_mode, - decomposition_mode, - ) - - -def get_mainloop_arguments_3x( - kernel_schedule: KernelScheduleType, - element_A, - element_B, - alignment_A: int, - alignment_B: int) -> ctypes.Structure: - """ - Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters. - - :param kernel_schedule: type of kernel schedule to be used in the mainloop - :type kernel_schedule: cutlass_library.KernelScheduleType - :param element_A: data type of operand A - :param element_B: data type of operand B - :param alignment_A: alignment of operand A - :type alignment_A: int - :param alignment_B: alignment of operand B - :type alignment_B: int - - :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters - :rtype: ctypes.Structure - """ - class _MainloopArgumentsTma(ctypes.Structure): - _fields_ = [ - ("ptr_A", ctypes.c_void_p), - ("stride_A", StrideBatched_), - ("ptr_B", ctypes.c_void_p), - ("stride_B", StrideBatched_), - ("mma_promotion_interval", ctypes.c_int) - ] - - @staticmethod - def from_generic_mainloop_args(args: GenericMainloopArguments3x_): - return _MainloopArgumentsTma( - args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, - args.mma_promotion_interval - ) - - class _MainloopArgumentsMultistage(ctypes.Structure): - _fields_ = [ - ("ptr_A", ctypes.c_void_p), - ("stride_A", StrideBatched_), - ("ptr_B", ctypes.c_void_p), - ("stride_B", StrideBatched_), - ] - - @staticmethod - def from_generic_mainloop_args(args: GenericMainloopArguments3x_): - return _MainloopArgumentsMultistage( - args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, - ) - - # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure. - # Should that become not the case, this is the place to return custom ctypes - # structures based on selected kernel schedule. - return _MainloopArgumentsTma - - -def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue): - if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"): - _EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt - else: - _EpilogueOutputOpParams = epilogue_functor.epilogue_type - - if hasattr(epilogue_functor, "visitor"): - class _EpilogueArguments(ctypes.Structure): - _fields_ = [ - ("epilogue", _EpilogueOutputOpParams), - ("arg_C", epilogue_functor.arg_c_type), - ("arg_D", epilogue_functor.arg_d_type) - ] - - def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None: - self.epilogue = output_op - self.arg_C = epilogue_functor.arg_c_type(ptr_c) - self.arg_D = epilogue_functor.arg_d_type(ptr_d) - else: - class _EpilogueArguments(ctypes.Structure): - _fields_ = [ - ("epilogue", _EpilogueOutputOpParams), - ("ptr_C", ctypes.c_void_p), - ("stride_C", StrideBatched_), - ("ptr_D", ctypes.c_void_p), - ("stride_D", StrideBatched_), - ] - - class _HardwareInfo(ctypes.Structure): - _fields_ = [ - ("device_id", ctypes.c_int), - ("sm_count", ctypes.c_int), - ("max_active_clusters", ctypes.c_int), - ("cluster_shape", dim3_), - ("cluster_shape_fallback", dim3_), - ] - - class _GemmArguments(ctypes.Structure): - _fields_ = [ - ("mode", ctypes.c_int), - ("problem_size", GemmCoordBatched_), - ("mainloop", mainloop_arguments), - ("epilogue", _EpilogueArguments), - ("hw_info", _HardwareInfo), - ("scheduler", type(scheduler_args)), - ] - - return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo - - -def get_gemm_arguments(epilogue_functor): - _EpilogueOutputOpParams = epilogue_functor.epilogue_type - - class _GemmArguments(ctypes.Structure): - _fields_ = [ - # Arguments from UniversalArgumentsBase - ("mode", ctypes.c_int), - ("problem_size", GemmCoord_), - ("batch_count", ctypes.c_int), - ("batch_stride_D", ctypes.c_longlong), - # Remaining arguments - ("epilogue", _EpilogueOutputOpParams), - ("ptr_A", ctypes.c_void_p), - ("ptr_B", ctypes.c_void_p), - ("ptr_C", ctypes.c_void_p), - ("ptr_D", ctypes.c_void_p), - ("batch_stride_A", ctypes.c_longlong), - ("batch_stride_B", ctypes.c_longlong), - ("batch_stride_C", ctypes.c_longlong), - ("stride_a", ctypes.c_longlong), - ("stride_b", ctypes.c_longlong), - ("stride_c", ctypes.c_longlong), - ("stride_d", ctypes.c_longlong), - ("lda", ctypes.c_longlong), - ("ldb", ctypes.c_longlong), - ("ldc", ctypes.c_longlong), - ("ldd", ctypes.c_longlong), - ("ptr_gather_A_indices", ctypes.c_void_p), - ("ptr_gather_B_indices", ctypes.c_void_p), - ("ptr_scatter_D_indices", ctypes.c_void_p) - ] - - return _GemmArguments, _EpilogueOutputOpParams - - -def get_gemm_arguments_streamk(epilogue_functor): - _EpilogueOutputOpParams = epilogue_functor.epilogue_type - - class _GemmArguments(ctypes.Structure): - _fields_ = [ - ("mode", ctypes.c_int), - ("problem_size", GemmCoord_), - ("batch_count", ctypes.c_int), - ("epilogue", _EpilogueOutputOpParams), - ("ptr_A", ctypes.c_void_p), - ("ptr_B", ctypes.c_void_p), - ("ptr_C", ctypes.c_void_p), - ("ptr_D", ctypes.c_void_p), - ("batch_stride_A", ctypes.c_longlong), - ("batch_stride_B", ctypes.c_longlong), - ("batch_stride_C", ctypes.c_longlong), - ("batch_stride_D", ctypes.c_longlong), - ("stride_a", ctypes.c_longlong), - ("stride_b", ctypes.c_longlong), - ("stride_c", ctypes.c_longlong), - ("stride_d", ctypes.c_longlong), - ("lda", ctypes.c_longlong), - ("ldb", ctypes.c_longlong), - ("ldc", ctypes.c_longlong), - ("ldd", ctypes.c_longlong), - ("avail_sms", ctypes.c_int) - ] - - return _GemmArguments, _EpilogueOutputOpParams - - -########################################################################################### -# GEMM Grouped -########################################################################################### - - -def get_gemm_grouped_arguments(epilogue_functor): - _EpilogueOutputOpParams = epilogue_functor.epilogue_type - - class _GEMMGroupedArguments(ctypes.Structure): - _fields_ = [ - ("problem_sizes", ctypes.c_void_p), - ("problem_count", ctypes.c_int), - ("threadblock_count", ctypes.c_int), - ("output_op", _EpilogueOutputOpParams), - ("ptr_A", ctypes.c_void_p), - ("ptr_B", ctypes.c_void_p), - ("ptr_C", ctypes.c_void_p), - ("ptr_D", ctypes.c_void_p), - ("lda", ctypes.c_void_p), - ("ldb", ctypes.c_void_p), - ("ldc", ctypes.c_void_p), - ("ldd", ctypes.c_void_p), - ("host_problem_sizes", ctypes.c_void_p) - ] - - return _GEMMGroupedArguments, _EpilogueOutputOpParams - - -############################################################################################ -# Convolution2D -############################################################################################ - - -class Conv2DProblemSize_(ctypes.Structure): - _fields_ = [ - ("N", ctypes.c_int), - ("H", ctypes.c_int), - ("W", ctypes.c_int), - ("C", ctypes.c_int), - ("P", ctypes.c_int), - ("Q", ctypes.c_int), - ("K", ctypes.c_int), - ("R", ctypes.c_int), - ("S", ctypes.c_int), - ("pad_h", ctypes.c_int), - ("pad_w", ctypes.c_int), - ("stride_h", ctypes.c_int), - ("stride_w", ctypes.c_int), - ("dilation_h", ctypes.c_int), - ("dilation_w", ctypes.c_int), - ("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1 - ("split_k_slices", ctypes.c_int), - ("groups", ctypes.c_int) - ] - - def __init__(self, problem_size) -> None: - for field_name, _ in self._fields_: - setattr(self, field_name, getattr(problem_size, field_name)) - - -class Layout4D(ctypes.Structure): - _fields_ = [("stride", ctypes.c_int * 3)] - - def __init__(self, tensor_ref): - stride = tensor_ref.stride() - setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2))) - - -class TensorRef_(ctypes.Structure): - _fields_ = [ - ("ptr", ctypes.c_void_p), - ("layout", Layout4D) - ] - - def __init__(self, tensor_ref): - setattr(self, "ptr", tensor_ref.data()) - setattr(self, "layout", Layout4D(tensor_ref.layout())) - - -class TensorRef2D_(ctypes.Structure): - _fields_ = [ - ("ptr", ctypes.c_void_p), - ("stride", ctypes.c_int) - ] - - -def get_conv2d_arguments(epilogue_functor): - _EpilogueOutputOpParams = epilogue_functor.epilogue_type - - class _Conv2dArguments(ctypes.Structure): - _fields_ = [ - ("conv_kind", ctypes.c_int), - ("problem_size", Conv2DProblemSize_), - ("ptr_A", ctypes.c_void_p), - ("ptr_B", ctypes.c_void_p), - ("ptr_C", ctypes.c_void_p), - ("ptr_D", ctypes.c_void_p), - ("tensor_C_numel", ctypes.c_int), - ("output_op", _EpilogueOutputOpParams), - ("split_k_mode", ctypes.c_int) - ] - - return _Conv2dArguments, _EpilogueOutputOpParams - - -############################################################################################ -# Reduction -############################################################################################ - - -def get_reduction_params(epilogue_functor): - _EpilogueOutputParams = epilogue_functor.epilogue_type - - class _ReductionParams(ctypes.Structure): - _fields_ = [ - ("problem_size", MatrixCoord_), - ("partitions", ctypes.c_int), - ("partition_stride", ctypes.c_longlong), - ("workspace", TensorRef2D_), - ("destination", TensorRef2D_), - ("source", TensorRef2D_), - ("output_op", _EpilogueOutputParams), - ] - - return _ReductionParams, _EpilogueOutputParams - - -########################################################################################### -# Epilogue Visitor Type Factory -########################################################################################### - -class Empty(ctypes.Structure): - _fields_ = [] - - def __init__(self, *arg) -> None: - pass - -class EmptyByte(ctypes.Structure): - _fields_ = [ - ("byte", ctypes.c_byte) - ] - - def __init__(self, *arg) -> None: - pass - -class EBO: - def __init__(self, index: int, type) -> None: - self.index = index - self.type = type - - def __eq__(self, other) -> bool: - if isinstance(other, EBO): - return self.index == other.index and self.type == other.type - return False - - def __hash__(self) -> int: - return hash((self.index, self.type)) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self) -> str: - return f"<{self.index}, {self.type}>" - - -def tuple_factory_(input_tuple, dtype, constants=[0,1]): - """ - The factory function generating cute::Tuple with input tuple - :param input_tuple: the input tuple - :type input_tuple: tuple - :param dtype: the data type for non-constant values - :type dtype: str, "int32_t", "int", "int64_t" - :param constant: the values that will be treated as constants - :type constant: list[int] - - :return: ctype structure representing the cute::Tuple - :return: the empty base classes of the tuple - """ - - # The empty base classes of the current tuple - empty_bases = [] - # The first non empty base class - first_non_empty_base = None - # The ctype fields of the current tuple - ctype_fields = [] - - for idx, entry in enumerate(input_tuple): - # For nested tuples - if isinstance(entry, tuple): - sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants) - if ctypes.sizeof(sub_tuple_ctype) == 0: - # The empty tuple base class is also an empty EBO - empty_bases.append(EBO(idx, entry)) - else: - if first_non_empty_base is None: - first_non_empty_base = sub_empty_bases - ctype_fields.append((f"entry_{idx}", sub_tuple_ctype)) - else: - if entry in constants: - empty_bases.append(EBO(idx, entry)) - ctype_fields.append((f"entry_{idx}", Empty)) - else: - ctype_fields.append((f"entry_{idx}", dtype)) - if first_non_empty_base is None: - first_non_empty_base = [] - - # Create the ctype tuple - class TupleType(ctypes.Structure): - _fields_ = ctype_fields - - def __init__(self, args) -> None: - fields = self._fields_ - - assert len(fields) == len(args) - for field, arg in zip(fields, args): - name = field[0] - field_type = field[1] - setattr(self, name, field_type(arg)) - - return TupleType, empty_bases - -def tuple_factory(input_tuple, dtype: str, constants=[0,1]): - """ - The factory function generating cute::Tuple with input tuple - :param input_tuple: the input tuple - :type input_tuple: tuple - :param dtype: the data type for non-constant values - :type dtype: str, "int32_t", "int", "int64_t" - :param constant: the values that will be treated as constants - :type constant: list[int] - - :return: ctype structure representing the cute::Tuple - :return: the empty base classes of the tuple - """ - # Step 1: convert the dtype - if dtype == "int64_t": - dtype = ctypes.c_longlong - elif dtype in ["int", "int32_t"]: - dtype = ctypes.c_int32 - else: - raise NotImplementedError(f"Type {dtype} is not supported") - - tuple_type, _ = tuple_factory_(input_tuple, dtype, constants) - - if ctypes.sizeof(tuple_type) == 0: - return EmptyByte - return tuple_type - - -def visitor_factory(node_types, node_names): - """ - Creates the argument type of epilogue visitor type - - :param node_types: list of argument types under ctypes - :param node_names: list of argument names under str - - :return: tuple type in ctypes.Structure - """ - ctypes_field = [] - # Struct is used when number of nodes < 4 - # Because the Sm90VisitorImplBase has specification up to 4 nodes - # in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp` - if len(node_types) <= 4: - for idx, node_type in enumerate(node_types): - if ctypes.sizeof(node_type) == 0: - # Special case for empty struct - # 1 byte placeholder is used for correct alignment - ctypes_field.append((node_names[idx], ctypes.c_byte)) - else: - ctypes_field.append((node_names[idx], node_type)) - - class VisitorType(ctypes.Structure): - _fields_ = ctypes_field - - def __init__(self, kwargs) -> None: - for field in self._fields_: - fname, ftype = field - if ftype != ctypes.c_byte: - setattr(self, fname, ftype(kwargs)) - - # For cases with more than 4 nodes, tuple is used - else: - for idx, node_type in enumerate(node_types): - ctypes_field.append((node_names[idx], node_type)) - - class VisitorType(ctypes.Structure): - _fields_ = ctypes_field - - def __init__(self, kwargs) -> None: - for field in self._fields_: - fname, ftype = field - setattr(self, fname, ftype(kwargs)) - - return VisitorType diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py deleted file mode 100644 index 0b66ce8a2402a109e2da00613e7255760685855c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py +++ /dev/null @@ -1,462 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import ctypes -import json -import os -import sqlite3 -import subprocess -import tempfile - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") -nvrtc = lazy_import("cuda.nvrtc") -from cutlass_library import SubstituteTemplate - -import cutlass_cppgen -from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger -from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal -from cutlass_cppgen.backend.library import ApiVersion -from cutlass_cppgen.backend.utils.device import device_cc - -IncludeTemplate = r"""#include "${include}" -""" - - -def compile_with_nvcc(cmd, source, error_file): - succeed = True - try: - subprocess.check_output(cmd, stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - error_message = e.output.decode() - with open(error_file, "w") as error_out: - error_log = "Compilation error for the following kernel: \n" - error_log += source - error_log += "\nError Message:\n" - error_log += error_message - error_out.write(error_log) - succeed = False - if not succeed: - # Print the error log to stdout if log level is set to warning or higher - # verbosity. Otherwise, simply point to the error log file. - logger.warning(error_log) - raise Exception(f"Invalid Kernel. See '{error_file}' for details.") - - -class CompilationOptions: - """ - Compilation options. - """ - - def __init__(self, flags, arch, include_paths=[]): - self.includes = [] - self.include_paths = include_paths - self.flags = flags - self.arch = arch - - def get_str(self): - opts = [] - for flag in self.flags: - opts.append(flag) - - for incl in self.include_paths: - opts.append(f"--include-path={incl}") - - arch_flag = f"-arch=sm_{self.arch}" - if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: - arch_flag += "a" - opts.append(arch_flag) - - return " ".join(opts) - - def get(self): - options = [] - - for flag in self.flags: - options.append(bytes(str.encode(flag))) - - for incl in self.include_paths: - options.append(bytes(str.encode(f" --include-path={incl}"))) - - arch_flag = f" -arch=sm_{self.arch}" - if self.arch in [90, 100, 101, 103, 120, 121]: - arch_flag += "a" - - options.append(bytes(str.encode(arch_flag))) - - return options - - -def convertToBinaryData(filename): - with open(filename, "rb") as file: - blobData = file.read() - return blobData - - -def CDLLBin(host_binary): - tempfile.tempdir = "./" - temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True) - with open(temp_so.name, "wb") as file: - file.write(host_binary) - host_lib = ctypes.CDLL(temp_so.name) - return host_lib - - -class ArtifactManager: - """ - Artifact manager - """ - - def __init__(self) -> None: - connection = sqlite3.connect(CACHE_FILE) - cursor = connection.cursor() - # Create the table if it does not already exist - sqlite_create_table_query = """ - CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, - cubin BLOB NOT NULL, - hostbin BLOB NOT NULL, - op_name TEXT NOT NULL, - op_attrs TEXT NOT NULL) - """ - cursor.execute(sqlite_create_table_query) - connection.commit() - cursor.close() - - self._nvrtc_compile_options = ["-std=c++17", "-default-device"] - self._nvcc_compile_options = [ - "-std=c++17", - "--expt-relaxed-constexpr", - "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", - ] - self.nvcc() - self.compiled_cache_device = {} - self.compiled_cache_host = {} - - def nvrtc(self): - self.backend = "nvrtc" - self.default_compile_options = self._nvrtc_compile_options - - def nvcc(self): - self.backend = "nvcc" - self.default_compile_options = self._nvcc_compile_options - - def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): - connection = sqlite3.connect(CACHE_FILE) - cursor = connection.cursor() - sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" - - hostbin = convertToBinaryData(hostfile) - - data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)) - - cursor.execute(sqlite_insert_blob_query, data_tuple) - connection.commit() - cursor.close() - - def load_operation(self, op_key, extra_funcs): - connection = sqlite3.connect(CACHE_FILE) - cursor = connection.cursor() - sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" - cursor.execute(sqlite_fetch_blob_query, (op_key,)) - record = cursor.fetchall() - if len(record) == 0: - return False - for row in record: - key, cubin_image, host_binary, operation_name, op_attr = row - op_attr = json.loads(op_attr) - err, module = cuda.cuModuleLoadData(cubin_image) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("Cuda Error: {}".format(err)) - - err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) - self.compiled_cache_device[key] = kernel - - compiled_host_fns = {} - host_lib = CDLLBin(host_binary) - - func_name = operation_name + "_get_params" - func = getattr(host_lib, func_name) - func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0]) - compiled_host_fns["get_args"] = func - - func_name = operation_name + "_shared_memory_size" - func = getattr(host_lib, func_name) - compiled_host_fns["shared_memory_capacity"] = func() - - for attr in op_attr: - if isinstance(attr, str): - func_name = operation_name + "_" + attr - func = getattr(host_lib, func_name) - - # Set the return type of the function - if attr in extra_funcs and extra_funcs[attr] != None: - func.restype = extra_funcs[attr] - - compiled_host_fns[attr] = func - - self.compiled_cache_host[key] = compiled_host_fns - return True - - def emit_compile_(self, operation_list, compilation_options, host_compilation_options): - """ - Compile a list of kernels and store them into database - """ - source_buffer_device = "" - source_buffer_host = "" - # 1. include - includes = [] - for operation in operation_list: - for incl in operation.emitter.includes: - if incl not in includes: - includes.append(incl) - - includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes - for incl in includes: - source_buffer_device += SubstituteTemplate( - IncludeTemplate, - {"include": incl}, - ) - - for incl in includes_host: - source_buffer_host += SubstituteTemplate( - IncludeTemplate, - {"include": incl}, - ) - - # 2. Operations - for operation in operation_list: - source_buffer_device += operation.emit() - source_buffer_host += operation.emit() - values = { - "operation_name": operation.name(), - "operation_suffix": operation.emitter.operation_suffix, - } - source_buffer_device += SubstituteTemplate( - operation.KernelTemplate, - values, - ) - source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) - - if self.backend == "nvrtc": - # 3. compile - err, program = nvrtc.nvrtcCreateProgram( - str.encode(source_buffer_device), - bytes(str.encode("module.cu")), - 0, [], []) - - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError("NVRTC Error: {}".format(err)) - - # Compile program - options = compilation_options.get() - - err, = nvrtc.nvrtcCompileProgram(program, len(options), options) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - error_string = "NVRTC Error: {}\n".format(err) - - # Get log from compilation - err, logSize = nvrtc.nvrtcGetProgramLogSize(program) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError("NVRTC Error: {}".format(err)) - - log = b" " * logSize - err, = nvrtc.nvrtcGetProgramLog(program, log) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError("NVRTC Error: {}".format(err)) - - raise RuntimeError(error_string + log.decode() + source_buffer_device) - - # Get data from compilation - err, dataSize = nvrtc.nvrtcGetCUBINSize(program) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError("NVRTC Error: {}".format(err)) - - cubin_image = b" " * dataSize - (err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError("NVRTC Error: {}".format(err)) - - else: # with nvcc backend - # emit code - tempfile.tempdir = "./" - temp_cu = tempfile.NamedTemporaryFile( - prefix="kernel", suffix=".cu", delete=True) - temp_cubin = tempfile.NamedTemporaryFile( - prefix="kernel", suffix=".cubin", delete=True) - with open(temp_cu.name, "w") as file: - file.write(source_buffer_device) - - # compile with nvcc - cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}" - values = { - "cuda_install_path": cuda_install_path(), - "options": compilation_options.get_str(), - "srcfile": temp_cu.name, - "tarfile": temp_cubin.name, - } - cmd = SubstituteTemplate(cmd_template, values) - compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt") - - # load the cubin image - with open(temp_cubin.name, "rb") as file: - cubin_image = file.read() - - tempfile.tempdir = "./" - temp_src = tempfile.NamedTemporaryFile( - prefix="host_src", suffix=".cu", delete=True) - - # Write the host source - with open(temp_src.name, "w") as outfile: - outfile.write(source_buffer_host) - - temp_dst = tempfile.NamedTemporaryFile( - prefix="host_func", suffix=".so", delete=True) - - # Set up host compilation arguments - cmd = [] - cmd.append(f"{cuda_install_path()}/bin/nvcc") - cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"]) - cmd.extend(host_compilation_options.get_str().split(" ")) - cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"]) - - # Comile and load the library - compile_with_nvcc( cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt") - host_lib = ctypes.CDLL(temp_dst.name) - - return cubin_image, host_lib, temp_dst - - def add_module(self, operations, compile_options=None, bypass_cache=False): - """ - Insert a new compiled device module - """ - include_paths = [ - cuda_install_path() + "/include", - CUTLASS_PATH + "/include", - CUTLASS_PATH + "/tools/util/include", - CUTLASS_PATH + "/python/cutlass/cpp/include", - ] - - cutlass_cppgen.initialize_cuda_context() - arch = device_cc() - - host_compile_options = CompilationOptions( - self._nvcc_compile_options, arch, include_paths) - if compile_options is None: - compile_options = CompilationOptions( - self.default_compile_options, arch, include_paths) - # save the cubin - operation_key = [] - operation_list = [] - for operation in operations: - # step 1: get kernel string as key - key = operation.rt_module.emit() + operation.procedural_name() + self.backend - # step 1: check if the operation is in cache - compiled_kernel = self.compiled_cache_device.get(key) - - if compiled_kernel is None and not bypass_cache: - hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {})) - if hit: - compiled_kernel = self.compiled_cache_device.get(key) - assert compiled_kernel is not None - if compiled_kernel is not None: - operation.rt_module.kernel = compiled_kernel - compiled_host_fns = self.compiled_cache_host.get(key) - assert compiled_host_fns is not None - for key in compiled_host_fns.keys(): - setattr(operation.rt_module, key, compiled_host_fns[key]) - operation.rt_module.initialize() - else: - operation_list.append(operation.rt_module) - operation_key.append(key) - - if len(operation_list) > 0: - cubin_image, host_lib, host_file = self.emit_compile_( - operation_list, compile_options, host_compile_options) - - err, module = cuda.cuModuleLoadData(cubin_image) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("Cuda Error: {}".format(err)) - - operation_name = [] - operation_attr = [] - for operation, key in zip(operation_list, operation_key): - # get device kernels - err, operation.kernel = cuda.cuModuleGetFunction( - module, - bytes(str.encode(operation.name())) - ) - operation_name.append(operation.name()) - self.compiled_cache_device[key] = operation.kernel - # get host functions - compiled_host_fns = {} - op_attr = [] - - # get param size - func_name = operation.name() + "_get_param_size" - func = getattr(host_lib, func_name) - param_size = func() - - func_name = operation.name() + "_get_params" - func = getattr(host_lib, func_name) - func.argtype = operation.argtype - func.restype = ctypes.POINTER(ctypes.c_char * param_size) - setattr(operation, "get_args", func) - compiled_host_fns["get_args"] = func - - # set shared memory size - func_name = operation.name() + "_shared_memory_size" - func = getattr(host_lib, func_name) - setattr(operation, "shared_memory_capacity", func()) - compiled_host_fns["shared_memory_capacity"] = func() - # set the maximum dynamic shared size - operation.initialize() - - # get extra functions - op_attr.append(param_size) - - if hasattr(operation, "extra_funcs"): - for suffix, ret_type in operation.extra_funcs.items(): - func_name = operation.name() + "_" + suffix - func = getattr(host_lib, func_name) - if ret_type is not None: - func.restype = ret_type - setattr(operation, suffix, func) - compiled_host_fns[suffix] = func - op_attr.append(suffix) - - operation_attr.append(op_attr) - self.compiled_cache_host[key] = compiled_host_fns - - for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr): - self.insert_operation( - key, cubin_image, host_file.name, operation_name, operation_attr) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py deleted file mode 100644 index 03679c434e1a63e9d1f9f2d1571dacedcf6e1470..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py +++ /dev/null @@ -1,700 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -from __future__ import annotations - -import ctypes -from typing import Union - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -from cutlass_library import SubstituteTemplate -import numpy as np - -from cutlass_library import ( - ConvKindNames, - ConvKindTag, - DataTypeNames, - DataTypeSize, - DataTypeTag, - IteratorAlgorithmNames, - IteratorAlgorithmTag, - LayoutTag, - LayoutType, - MathOperation, - MathOperationTag, - OpcodeClass, - OpcodeClassNames, - OpcodeClassTag, - OperationKind, - ShortDataTypeNames, - ShortLayoutTypeNames, - SplitKMode, - StrideSupport, - StrideSupportTag, - SwizzlingFunctor, - SwizzlingFunctorTag, - get_complex_from_real, -) - -from cutlass_cppgen.backend.arguments import ArgumentBase -from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments -from cutlass_cppgen.backend.library import ( - EmissionType, - TensorDescription, - TileDescription, -) -from cutlass_cppgen.backend.memory_manager import device_mem_alloc -from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass_cppgen.backend.utils.device import to_device_ptr -from cutlass_cppgen.shape import GemmCoord - - -class Conv2dArguments(ArgumentBase): - """ - Argument wrapper for Conv2d. It encodes problem information and - user-provide tensors into the kernel's argument. - - :param operation: the Conv2d operation to take the argument - :type operation: :class:`cutlass_cppgen.backend.Conv2dOperation` - :param problem_size: the Conv2d problem size - :type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize` - :param A: tensor A - :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param B: tensor B - :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param C: tensor C - :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param D: tensor D - :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial - :type split_k_mode: cutlass_library.library.SplitKMode, optional - :param output_op: output operator, optional - :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` - :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) - :type stream: :class:`cuda.cuda.CUstream` - """ - - def __init__(self, operation, problem_size, A, B, C, D, - split_k_mode=SplitKMode.Serial, **kwargs, ) -> None: - self.operation = operation - self.conv_kind = operation.conv_kind - self.layout_A = operation.A.layout - self.layout_B = operation.B.layout - self.layout_C = operation.C.layout - - self.element_A = operation.A.element - self.element_B = operation.B.element - self.element_C = operation.C.element - - if self.layout_C == LayoutType.TensorNC32HW32: - raise Exception("Layout type TensorNC32HW32 is not currently supported") - - super().__init__(A, B, C, D, **kwargs) - - if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1: - self.split_k_mode = split_k_mode - self.split_k_slices = kwargs["split_k_slices"] - else: - self.split_k_mode = SplitKMode.Serial - self.split_k_slices = 1 - - if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel: - self.output_op = kwargs["output_op"] - else: - self.output_op = self.operation.epilogue_type(1.0, 0.0) - - self.problem_size = problem_size - self.problem_size.split_k_slices = self.split_k_slices - - self.initialize() - - def get_arguments(self): - tc_numel = -1 - if hasattr(self, "tensor_c_numel"): - tc_numel = self.tensor_c_numel - - self.c_arguments = self.operation.argument_type( - int(self.conv_kind), - self.problem_size.ctype, - int(to_device_ptr(self.ptr_A)), - int(to_device_ptr(self.ptr_B)), - int(to_device_ptr(self.ptr_C)), - int(to_device_ptr(self.ptr_D)), - tc_numel, - self.output_op, - int(self.split_k_mode) - ) - - def initialize(self): - self.launch_config = self.operation.rt_module.plan(self) - - self.get_arguments() - - # Allocate and initialize device workspace - device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments) - if device_workspace_size > 0: - self.workspace_buffer = device_mem_alloc(device_workspace_size) - workspace_ptr = self.workspace_buffer.ptr - err, = cuda.cuMemsetD32( - workspace_ptr, 0, device_workspace_size // 4) - else: - workspace_ptr = None - - self.semaphore = 0 - if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel: - self.ptr_D = workspace_ptr - # Reset arguments now that ptr_D has been updated - self.get_arguments() - elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial: - self.semaphore = workspace_ptr - - params_ = self.operation.rt_module.get_args( - self.c_arguments, ctypes.c_void_p(int(self.semaphore))) - self.host_workspace = bytearray(params_.contents) - self.device_workspace = None - - def sync(self): - """ - Synchronize the arguments. If the input tensor is in host, - copy it from device to host. - """ - return super().sync() - - -class Conv2dRT(ExecutableOperation): - """ - Conv2dRT manages the CUTLASS runtime components - """ - - KernelTemplate = r""" -extern "C" -__global__ void -${operation_name}(${operation_name}${operation_suffix}::Params params) { - - // Dynamic shared memory base pointer - extern __shared__ int SharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - ${operation_name}${operation_suffix}::SharedStorage *shared_storage = - reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); - - ${operation_name}${operation_suffix} op; - - op(params, *shared_storage); -} - """ - - HostTemplate = r""" -extern "C" { - // Get the size of params in bytes - int ${operation_name}_get_param_size(){ - return sizeof(${operation_name}${operation_suffix}::Params); - } - - // Get the size of dynamic shared memory in bytes - int ${operation_name}_shared_memory_size() { - return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); - } - - using ElementA = typename ${operation_name}_base::ElementA; - using ElementB = typename ${operation_name}_base::ElementB; - using ElementC = typename ${operation_name}_base::ElementC; - using LayoutA = typename ${operation_name}_base::LayoutA; - using LayoutB = typename ${operation_name}_base::LayoutB; - using LayoutC = typename ${operation_name}_base::LayoutC; - using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp; - - struct ${operation_name}_TemporaryArgs { - int conv_kind; - cutlass::conv::Conv2dProblemSize problem_size; - ElementA* ptr_A; - ElementB* ptr_B; - ElementC* ptr_C; - ElementC* ptr_D; - int tensor_c_numel; - typename EpilogueOutputOp::Params epilogue_params; - int split_k_mode; - }; - - typename ${operation_name}${operation_suffix}::Arguments - construct_arguments(${operation_name}_TemporaryArgs args) { - cutlass::conv::Operator conv_operator = static_cast(args.conv_kind); - auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size); - auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size); - auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size); - auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size); - - auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3); - if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) { - // C is interpreted as bias - tc_C = {0, 0, 0, 0}; - } - - cutlass::TensorRef tref_A(args.ptr_A, LayoutA::packed(tc_A)); - cutlass::TensorRef tref_B(args.ptr_B, LayoutB::packed(tc_B)); - cutlass::TensorRef tref_C(args.ptr_C, LayoutC::packed(tc_C)); - cutlass::TensorRef tref_D(args.ptr_D, LayoutC::packed(tc_D)); - - return { - args.problem_size, - tref_A, - tref_B, - tref_C, - tref_D, - args.epilogue_params, - static_cast(args.split_k_mode) - }; - } - - // Get the params as byte array - char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) { - auto arguments = construct_arguments(args); - typename ${operation_name}${operation_suffix}::Params* params; - params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore); - - char *bytes = ((char*)(params)); - char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; - for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) - output[i] = bytes[i]; - - return output; - } - - dim3 ${operation_name}_get_grid_shape( - int conv_kind, - cutlass::conv::Conv2dProblemSize problem_size, - cutlass::gemm::GemmCoord tile_size, - int split_k_slices - ) { - - using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle; - auto tiled_shape = Swizzle::get_tiled_shape( - static_cast(conv_kind), - problem_size, - tile_size, - split_k_slices); - - return Swizzle::get_grid_shape(tiled_shape); - } - - size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) { - auto arguments = construct_arguments(args); - - // Temporarily define device::-level Conv2d so that we can call get_workspace_size - using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>; - return DeviceConv::get_workspace_size(arguments); - } -} - - """ - - def __init__(self, operation: "Conv2dOperation"): - super().__init__(operation) - self.extra_funcs = { - "get_grid_shape": dim3_, - "get_workspace_size": ctypes.c_uint64 - } - self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor) - self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p] - self.conv_kind = operation.conv_kind - - self.operation: Conv2dOperation = operation - - self.emitter = EmitConv2dInstance("_type") - - self.threads = operation.tile_description.num_threads - - self.swizzle_functor = operation.swizzling_functor - - def emit(self): - return self.emitter.emit(self.operation) - - def plan(self, arguments: Conv2dArguments): - tile_size = GemmCoord( - self.operation.tile_description.threadblock_shape[0], - self.operation.tile_description.threadblock_shape[1], - self.operation.tile_description.threadblock_shape[2], - ) - - grid = self.get_grid_shape( - int(self.conv_kind), - arguments.problem_size.ctype, - tile_size.ctype, - arguments.split_k_slices - ) - - return LaunchConfiguration( - [grid.x, grid.y, grid.z], [self.threads, 1, 1], - self.shared_memory_capacity) - - def initialize(self): - err, = cuda.cuFuncSetAttribute( - self.kernel, - attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - value=self.shared_memory_capacity) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error: {err}") - - -class Conv2dOperation: - """ - CUTLASS Conv2d operation description. - - :param conv_kind: convolution operator - :type conv_kind: :class:`cutlass_library.library.ConvKind` - - :param iterator_algorithm: Selects among several implementation - variants trading off performance with simplicity - :type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm` - - :param arch: GPU compute capability (sm_xx) - :type arch: int - - :param tile_description: tile description - :type tile_description: :class:`cutlass_cppgen.backend.TileDescription` - - :param A: tensor A description - :type A: :class:`cutlass_cppgen.backend.TensorDescription` - - :param B: tensor B description - :type B: :class:`cutlass_cppgen.backend.TensorDescription` - - :param C: tensor C description - :type C: :class:`cutlass_cppgen.backend.TensorDescription` - - :param D: tensor D description - :type D: :class:`cutlass_cppgen.backend.TensorDescription` - - :param element_epilogue: element type for computation in epilogue \ - :type element_epilogue: cutlass_library.library.DataType - - :param stride_support: distinguish among partial specializations that \ - accelerate certain problems where convolution stride is unit \ - :type stride_support: :class:`cutlass_library.library.StrideSupport` - - :param epilogue_functor: convolution epilogue functor - :type epilogue_functor: :class:`EpilogueFunctor` - - :param swizzling_functor: threadblock swizzling functor - """ - def __init__( - self, - conv_kind, - iterator_algorithm, - arch: int, - tile_description: TileDescription, - A: TensorDescription, - B: TensorDescription, - C: TensorDescription, - stride_support, - epilogue_functor, - swizzling_functor=SwizzlingFunctor.Identity1, - emission_type=EmissionType.Kernel, - **kwargs - ): - self.operation_kind: OperationKind = OperationKind.Conv2d - self.arch: int = arch - self.tile_description: TileDescription = tile_description - self.conv_kind = conv_kind - self.A: TensorDescription = A - self.B: TensorDescription = B - self.C: TensorDescription = C - self.epilogue_functor = epilogue_functor - self.iterator_algorithm = iterator_algorithm - self.stride_support = stride_support - self.swizzling_functor = swizzling_functor - - self.emission_type = emission_type - - self.rt_module: Conv2dRT = Conv2dRT(self) - self.argument_type = self.rt_module.argument_type - self.epilogue_type = self.rt_module.epilogue_type - - def run(self, arguments: Conv2dArguments) -> cuda.CUresult: - """ - Launch the cuda kernel with input arguments - - :param arguments: conv2d arguments - :type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments` - """ - - # launch the kernel - err = self.rt_module.run( - arguments.host_workspace, - arguments.device_workspace, - arguments.launch_config, - arguments.stream - ) - - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {err}") - - return err - - # - # Get function name - # - - def procedural_name(self): - """The full procedural name indicates architecture, extended name, tile size, and layout.""" - return self.configuration_name() - - def configuration_name(self): - """The full procedural name indicates architecture, extended name, tile size, and layout.""" - - opcode_class_name = OpcodeClassNames[ - self.tile_description.math_instruction.opcode_class - ] - - threadblock = "%dx%d_%dx%d" % ( - self.tile_description.threadblock_shape[0], - self.tile_description.threadblock_shape[1], - self.tile_description.threadblock_shape[2], - self.tile_description.stages, - ) - - if self.stride_support == StrideSupport.Unity: - configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" - else: - configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" - - return SubstituteTemplate( - configuration_name, - { - "arch": str(self.arch), - "opcode_class": opcode_class_name, - "extended_name": self.extended_name(), - "threadblock": threadblock, - "layout": self.layout_name(), - "alignment": "%d" % self.A.alignment - }, - ) - - def extended_name(self): - """Append data types if they differ from compute type.""" - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - "element_a": DataTypeNames[self.A.element], - "element_c": DataTypeNames[self.C.element], - "core_name": self.core_name(), - }) - - return extended_name - - def layout_name(self): - return "%s" % (ShortLayoutTypeNames[self.A.layout]) - - def core_name(self): - """The basic operation kind is prefixed with a letter indicating the accumulation type.""" - - intermediate_type = "" - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: - inst_shape = "%dx%dx%d" % tuple( - self.tile_description.math_instruction.instruction_shape) - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.accumulator_type(): - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - else: - inst_shape = "" - - return "%s%s%s%s_%s" % ( - ShortDataTypeNames[self.accumulator_type()], - inst_shape, - intermediate_type, - ConvKindNames[self.conv_kind], - IteratorAlgorithmNames[self.iterator_algorithm] - ) - - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - ] - return self.tile_description.math_instruction.math_operation in complex_operators - - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - def device_op(self): - """ - Returns a new Conv2dOperation object that is constructed with emission type - ``EmissionType.Device``. - - :return: operation ready for device-level code emission - :rtype: Conv2dOperation - """ - return Conv2dOperation( - self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description, - self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor, - emission_type=EmissionType.Device) - - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - - -class EmitConv2dInstance: - def __init__(self, operation_suffix=""): - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/conv/kernel/default_conv2d_fprop.h", - "cutlass/conv/kernel/default_conv2d_dgrad.h", - "cutlass/conv/kernel/default_conv2d_wgrad.h", - "cutlass/conv/device/implicit_gemm_convolution.h" - ] - self.template = """ -// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" -using ${operation_name}_base = -typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< - ${element_a}, - ${layout_a}, - ${element_b}, - ${layout_b}, - ${element_c}, - ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operator}, - ${iterator_algorithm}, - ${stride_support}, - ${align_a}, - ${align_b} ->::Kernel; - -struct ${operation_name}${operation_suffix}: - public ${operation_name}_base { }; - -""" - - self.template_device = """ -// Conv2d operation ${operation_name} - -using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< - ${element_a}, - ${layout_a}, - ${element_b}, - ${layout_b}, - ${element_c}, - ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operator}, - ${iterator_algorithm}, - ${stride_support}, - ${align_a}, - ${align_b} ->::Kernel; - -using DeviceKernel = - typename cutlass::conv::device::ImplicitGemmConvolution; -""" - - def emit(self, operation): - warp_shape = [int(operation.tile_description.threadblock_shape[idx] / - operation.tile_description.warp_count[idx]) for idx in range(3)] - - epilogue_vector_length = int(min( - operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - "operation_name": operation.procedural_name(), - "operation_suffix": self.operation_suffix, - "conv_kind": ConvKindTag[operation.conv_kind], - "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(), - "element_a": DataTypeTag[operation.A.element], - "layout_a": LayoutTag[operation.A.layout], - "element_b": DataTypeTag[operation.B.element], - "layout_b": LayoutTag[operation.B.layout], - "element_c": DataTypeTag[operation.C.element], - "layout_c": LayoutTag[operation.C.layout], - "element_accumulator": DataTypeTag[operation.accumulator_type()], - "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - "arch": "cutlass::arch::Sm%d" % operation.arch, - "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), - "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), - "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), - "warp_shape_m": str(warp_shape[0]), - "warp_shape_n": str(warp_shape[1]), - "warp_shape_k": str(warp_shape[2]), - "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), - "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), - "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), - "epilogue_vector_length": str(epilogue_vector_length), - "epilogue_functor": operation.epilogue_functor.emit(), - "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], - "stages": str(operation.tile_description.stages), - "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm], - "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), - "stride_support": StrideSupportTag[operation.stride_support], - "math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation], - "align_a": str(operation.A.alignment), - "align_b": str(operation.B.alignment), - } - - if operation.emission_type == EmissionType.Kernel: - conv2d_template = self.template - else: - conv2d_template = self.template_device - - return SubstituteTemplate(conv2d_template, values) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py deleted file mode 100644 index 49ad79c9c8ecc9cad6067a3d9543b2625344848b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py +++ /dev/null @@ -1,541 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import ctypes - -from cutlass_library import SubstituteTemplate -import numpy as np - -from cutlass_library import DataType, DataTypeTag -from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory -from cutlass_cppgen.backend.frontend import NumpyFrontend -from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag -from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor - -dtype2ctype = { - DataType.f16: ctypes.c_uint16, - DataType.bf16: ctypes.c_uint16, - DataType.f32: ctypes.c_float, - DataType.f64: ctypes.c_double, - DataType.s8: ctypes.c_int8, - DataType.s32: ctypes.c_int32 -} - -if is_torch_available(): - import torch - import torch.nn.functional as F - - -def get_scalar(value): - """ - Returns a scalar value from a container (e.g., np.ndarray) - """ - if is_numpy_tensor(value): - if value.size != 1: - raise Exception("Scalars used in epilogue must be of size 1") - return value.reshape(-1)[0] - elif is_torch_tensor(value): - if value.size != 1: - raise Exception("Scalars used in epilogue must be of size 1") - return value.reshape(-1)[0] - else: - return value - - -def to_ctype_value(value, dtype): - """ - Converts ``value`` to the corresponding storage needed for the ctype that - will store ``value``. - """ - scalar = get_scalar(value) - if dtype == DataType.f16: - # Convert f16 value into an integer - return int.from_bytes(np.float16(scalar).tobytes(), "little") - else: - return scalar - - -################################################################################################# -# -# Epilogue Functors -# -################################################################################################# - - -class EpilogueFunctorBase: - """ - Base class for thread-level epilogue functors - """ - - def __init__(self) -> None: - pass - - def emit(self, tag, template_argument): - template = """${tag}<${arguments}>""" - arguments = "" - for idx, arg in enumerate(template_argument): - arguments += arg - if idx < len(template_argument) - 1: - arguments += ", " - values = { - "tag": tag, - "arguments": arguments, - } - - return SubstituteTemplate(template, values) - - -class LinearCombination(EpilogueFunctorBase): - """ - Apply a linear combination operator to an array of elements - D = alpha * accumulator + beta * source - - :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. - Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes - when there are not enough data to store - - :param element_accumulator: Accumulator data type - - :param element_epilogue: data type used to compute linear combination - """ - - tag = "cutlass::epilogue::thread::LinearCombination" - - def __init__( - self, element_output, epilogue_vector_length, - element_accumulator=None, element_epilogue=None) -> None: - super().__init__() - - if element_accumulator is None: - element_accumulator = element_output - if element_epilogue is None: - element_epilogue = element_output - - self.element_output = element_output - self.element_accumulator = element_accumulator - self.element_epilogue = element_epilogue - self.epilogue_vector_length = epilogue_vector_length - - self.template_arguments = [ - DataTypeTag[element_output], - str(epilogue_vector_length), - DataTypeTag[element_accumulator], - DataTypeTag[element_epilogue], - ] - - c_element_epilogue = dtype2ctype[self.element_epilogue] - element_epilogue = self.element_epilogue - - class _EpilogueOutputOpParamsEVT(ctypes.Structure): - """ - Epilogue params when using the default linear combination of EVT, which - does not currently use {alpha,beta}_ptr_array - """ - - stride_type = tuple_factory((0,0,1), "int64_t", [0]) - _fields_ = [ - ("alpha", c_element_epilogue), - ("beta", c_element_epilogue), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), - ("dalpha", stride_type), - ("dbeta", stride_type), - ] - - def __init__(self, alpha, beta, *args) -> None: - self.alpha = to_ctype_value(alpha, element_epilogue) - self.beta = to_ctype_value(beta, element_epilogue) - - class _EpilogueOutputOpParams(ctypes.Structure): - _fields_ = [ - ("alpha", c_element_epilogue), - ("beta", c_element_epilogue), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), - ("alpha_ptr_array", ctypes.c_void_p), - ("beta_ptr_array", ctypes.c_void_p), - ] - - def __init__(self, alpha, beta, *args) -> None: - self.alpha = to_ctype_value(alpha, element_epilogue) - self.beta = to_ctype_value(beta, element_epilogue) - - def to_evt_params(self) -> _EpilogueOutputOpParamsEVT: - return _EpilogueOutputOpParamsEVT(self.alpha, self.beta) - - self.epilogue_type = _EpilogueOutputOpParams - self.epilogue_type_evt = _EpilogueOutputOpParamsEVT - - def emit(self): - return super().emit(self.tag, self.template_arguments) - - -class LinearCombinationClamp(LinearCombination): - """ - Applies a linear combination operator to an array of elements then clamps - the output before converting to the output element type. - - D = alpha * accumulator + beta * source + uniform - - :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. - Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes - when there are not enough data to store - - :param element_accumulator: Accumulator data type - - :param element_epilogue: data type used to compute linear combination - """ - - tag = "cutlass::epilogue::thread::LinearCombinationClamp" - - def __init__( - self, element_output, epilogue_vector_length, - element_accumulator=None, element_epilogue=None) -> None: - # Base constructor - super().__init__( - element_output, - epilogue_vector_length, - element_accumulator, - element_epilogue, - ) - - c_element_epilogue = dtype2ctype[self.element_epilogue] - element_epilogue = self.element_epilogue - - class _EpilogueOutputOpParams(ctypes.Structure): - _fields_ = [ - ("alpha", c_element_epilogue), - ("beta", c_element_epilogue), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), - ] - - def __init__(self, alpha, beta, *args) -> None: - self.alpha = to_ctype_value(alpha, element_epilogue) - self.beta = to_ctype_value(beta, element_epilogue) - - self.epilogue_type = _EpilogueOutputOpParams - - -class FastLinearCombinationClamp(EpilogueFunctorBase): - """ - Applies a linear combination operator to an array of elements then clamps - the output before converting to the output element type. - - D = alpha * accumulator + beta * source - - Note: The below method only when problem_size_K <= 256 for signed int8 gemm - or problem_size_K <= 128 for unsigned int8 gemm. The default approach is - above. - - :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. - Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes - when there are not enough data to store - """ - - tag = "cutlass::epilogue::thread::FastLinearCombinationClamp" - - def __init__(self, element_output, epilogue_vector_length, *args) -> None: - super().__init__() - - self.template_arguments = [ - DataTypeTag[element_output], str(epilogue_vector_length) - ] - - self.element_accumulator = DataType.s32 - self.element_epilogue = DataType.f32 - - # get epilogue output op - c_element_epilogue = dtype2ctype[self.element_epilogue] - element_epilogue = self.element_epilogue - - class _EpilogueOutputOpParams(ctypes.Structure): - _fields_ = [ - ("alpha", c_element_epilogue), - ("beta", c_element_epilogue), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), - ] - - def __init__(self, alpha, beta, *args) -> None: - self.alpha = to_ctype_value(alpha, element_epilogue) - self.beta = to_ctype_value(beta, element_epilogue) - - self.epilogue_type = _EpilogueOutputOpParams - - def emit(self): - return super().emit(self.tag, self.template_arguments) - - -class LinearCombinationGeneric(LinearCombination): - """ - Applies a linear combination operator followed by an activation function - to an array of elements. - - D = activation(alpha * accumulator + beta * source) - - :param activation_functor: input activation functor - - :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. - Usually it is 128/sizeof_bits_v, but we use 64 and 32 sometimes - when there are not enough data to store - - :param element_accumulator: Accumulator data type - - :param element_epilogue: data type used to compute linear combination - """ - - tag = "cutlass::epilogue::thread::LinearCombinationGeneric" - - def __init__( - self, activation_functor, - element_output, epilogue_vector_length, - element_accumulator=None, element_epilogue=None) -> None: - super().__init__( - element_output, - epilogue_vector_length, - element_accumulator, - element_epilogue, - ) - - self.template_arguments = [ - activation_functor.emit()] + self.template_arguments - - self.activation_functor = activation_functor - self.element_epilogue = element_epilogue - - # get epilogue output op - self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue) - - -class ActivationFunctor: - """ - Base class for frequently used activation functions - """ - - @staticmethod - def numpy(x: np.ndarray): - raise NotImplementedError() - - @classmethod - def emit(cls): - return ActivationOpTag[cls.binding_type] - - @staticmethod - def epilogue_output_op(element_epilogue): - c_element_epilogue = dtype2ctype[element_epilogue] - - class _EpilogueOutputOpParams(ctypes.Structure): - _fields_ = [ - ("alpha", c_element_epilogue), - ("beta", c_element_epilogue), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), - ] - - def __init__(self, alpha, beta, *args) -> None: - self.alpha = to_ctype_value(alpha, element_epilogue) - self.beta = to_ctype_value(beta, element_epilogue) - - return _EpilogueOutputOpParams - -class ActivationMeta(type): - @classmethod - def __call__(cls, x, *args): - if is_numpy_tensor(x): - return cls.numpy(x, *args) - elif is_torch_tensor(x): - return cls.torch(x, *args) - else: - raise NotImplementedError("Unsupported tensor type") - - @classmethod - def numpy(cls, *args): - raise NotImplementedError(f"Numpy reference for {cls.__name__[:-4]} is not implemented.") - - @classmethod - def torch(cls, *args): - raise NotImplementedError(f"PyTorch reference for {cls.__name__[:-4]} is not implemented.") - -############################################################################## -# identity operator -class identityMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - return x - - @classmethod - def torch(cls, x): - return x - -class identity(ActivationFunctor, metaclass=identityMeta): - binding_type = ActivationOp.Identity - - -############################################################################## -# ReLu operator -class reluMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - return np.where(x > 0, x, 0) - - @classmethod - def torch(cls, x): - return F.relu(x) - -class relu(ActivationFunctor, metaclass=reluMeta): - binding_type = ActivationOp.ReLU - - -############################################################################## -# Leaky ReLu operator -class leakyReLUMeta(ActivationMeta): - @classmethod - def numpy(cls, x, leaky_alpha): - return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha - - @classmethod - def torch(cls, x, leaky_alpha): - return F.leaky_relu(x, leaky_alpha) - -class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta): - binding_type = ActivationOp.LeakyReLU - - @staticmethod - def epilogue_output_op(element_epilogue): - c_element_epilogue = dtype2ctype[element_epilogue] - - class _EpilogueOutputOpParams(ctypes.Structure): - _fields_ = [ - ("alpha", c_element_epilogue), - ("beta", c_element_epilogue), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), - ("leaky_alpha", c_element_epilogue) - ] - - def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None: - self.alpha = to_ctype_value(alpha, element_epilogue) - self.beta = to_ctype_value(beta, element_epilogue) - self.alpha_ptr = 0 - self.beta_ptr = 0 - self.leaky_alpha = to_ctype_value(leaky_alpha, element_epilogue) - - return _EpilogueOutputOpParams - - -############################################################################## -# Tanh operator -class tanhMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - return np.tanh(x) - - @classmethod - def torch(cls, x): - return torch.tanh(x) - -class tanh(ActivationFunctor, metaclass=tanhMeta): - binding_type = ActivationOp.Tanh - - -############################################################################## -# Sigmoid operator -class sigmoidMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - return 1.0 / (1.0 + np.exp(-x)) - - @classmethod - def torch(cls, x): - return F.sigmoid(x) - -class sigmoid(ActivationFunctor, metaclass=sigmoidMeta): - binding_type = ActivationOp.Sigmoid - - -############################################################################## -# SiLu operator -class siluMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - return x * sigmoidMeta.numpy() - - @classmethod - def silu(cls, x): - return F.silu(x) - - -class silu(ActivationFunctor, metaclass=siluMeta): - binding_type = ActivationOp.SiLU - - -############################################################################## -# Hardswish operator -class hardswishMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0) - return x * relu6 / 6.0 - - @classmethod - def torch(cls, x): - return F.hardswish(x) - - -class hardswish(ActivationFunctor, metaclass=hardswishMeta): - binding_type = ActivationOp.HardSwish - - -############################################################################## -# GELU operator -class geluMeta(ActivationMeta): - @classmethod - def numpy(cls, x): - from scipy.special import erf - return 0.5 * x * (1 + erf(x / np.sqrt(2.0))) - - @classmethod - def torch(cls, x): - return F.gelu(x) - - -class gelu(ActivationFunctor, metaclass=geluMeta): - binding_type = ActivationOp.Gelu diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py deleted file mode 100644 index b61e983ab23bb5662d15e185184efa227351446d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor -from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py deleted file mode 100644 index 945dcf80e307eb870f31722822f959da03e6c421..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter -import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes -from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter -import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes -from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter -import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py deleted file mode 100644 index 72a7d8c04db5c8df2595fab8befaa07bf238c2f2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py +++ /dev/null @@ -1,159 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Base class for Epilogue Visitor Emitter -""" - -from cutlass_library import DataTypeTag -from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR - - -class FusionCallbacks: - def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None: - """ - Emit the EVT fusion callbacks - :param dag_ir: the DAG IR holding the epilogue visitor - :param cc: compute capability - :param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks - For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API - so that their shared memory can be explicitly reused - For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes. - """ - self.dag_ir = dag_ir - self.emit_CD = emit_CD - self.cc = cc - self.evt_cc = 90 if cc >= 90 else cc - if self.cc < 90: - self.namespace = "threadblock" - else: - self.namespace = "fusion" - - # - # Helper functions - # - - def get_visitor_name(self, node: str): - """ - Get the visitor name - """ - meta = self.dag_ir.get_node_meta(node) - if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0: - return f"EVT{meta.name_camel}" - else: - return meta.name_camel - - def emit(self): - node_metas = self.dag_ir.node_metas_topological_order() - epilogue_str = "" - # Step 1: emit individual node type decl - # emit the EVT & DAG connector - for meta in node_metas: - if not meta.disabled: - epilogue_str += self.emit_node(meta) - if not self.emit_CD and meta.name == "D": - continue - if isinstance(meta, TopoVisitorNode): - epilogue_str += self.emit_dag(meta) - else: - epilogue_str += self.emit_evt(meta) - - # Step 2: post-processing & get callback name - if not self.emit_CD: - if not self.dag_ir.has_node("C"): - epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n" - output_node = self.dag_ir.get_all_inputs("D")[0] - # The callback is the src of node D - callback_name = self.get_visitor_name(output_node) - else: - # The callback is the last node in the topological order - callback_name = self.get_visitor_name(node_metas[-1].name) - return epilogue_str, callback_name - - def emit_evt(self, node): - if self.dag_ir.in_degree(node.name) == 0: - return "" - - evt_tmp = f""" -using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT< - {node.name_camel}, -""" - sorted_children = self.dag_ir.get_all_inputs(node.name) - evt_node_strs = [f" {self.get_visitor_name(child_name)}" for child_name in sorted_children] - evt_tmp += ",\n".join(evt_node_strs) + ">;\n" - - return evt_tmp - - def emit_dag(self, node): - subgraph = node.subgraph - subgraph_nodes = subgraph.nodes_topological_order() - # Emit the Edge Tuple - edge_tuples = "cute::tuple<\n" - for n in subgraph_nodes[:-1]: - in_edges = subgraph.in_edges(n) - edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges] - sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))] - edge_tuple = " cute::seq<" - edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children] - edge_tuple += ", ".join(edge_str) + ">,\n" - - edge_tuples += edge_tuple - edge_tuples += " >" - - # Emit the node list - dag_nodes = "" - dag_node_strs = [] - for n in subgraph_nodes[:-1]: - n_meta = subgraph.get_node_meta(n) - if n_meta.disabled: - dag_node_strs.append(f" {self.get_visitor_name(n)}") - else: - dag_node_strs.append(f" {n_meta.name_camel}") - dag_nodes = ",\n".join(dag_node_strs) - - return f""" -using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor< - {DataTypeTag[node.subgraph.element_compute]}, - {edge_tuples}, -{dag_nodes} ->; -""" - - def emit_node(self, node): - if isinstance(node, TopoVisitorNode): - emission = "" - for node in node.subgraph.node_metas_topological_order(): - if not node.disabled: - emission += self.emit_node(node) - return emission - else: - return node.underlying_impl.type_decl diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py deleted file mode 100644 index db521e5279c57734a8e408938dc6ea95a608c6d8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py +++ /dev/null @@ -1,116 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Emitter for Sm100 Epilogue Visitor -""" - -from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag -from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape -from cutlass_cppgen.backend import GemmOperationUniversal -from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks -from cutlass_cppgen.backend.evt.ir.node import TupleEmitter - - -class Sm100CollectiveEpilogue: - def __init__(self, tile_description, - kernel_schedule, - epilogue_schedule, - element_accumulator, - element_d, - fusion_callbacks) -> None: - - self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule) - self.element_accumulator = element_accumulator - if fusion_callbacks.dag_ir.has_node("C"): - self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element - else: - self.element_c = DataType.void - self.element_d = element_d - self.schedule = epilogue_schedule - self.fusion_callbacks = fusion_callbacks - self.opclass = tile_description.math_instruction.opcode_class - - @property - def CtaTileMNK(self) -> str: - """ - The threadblock shape - """ - return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" - - @property - def EpilogueTileType(self) -> str: - """ - The epilogue tile type - """ - return "cutlass::epilogue::collective::EpilogueTileAuto" - - @property - def Schedule(self) -> str: - return EpilogueScheduleTag[self.schedule] - - def emit(self): - tuple_emitter = TupleEmitter("int64_t") - stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl - stride_C_str = stride_D_str - if self.fusion_callbacks.dag_ir.has_node("C"): - stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl - - callback_decl, callback_name = self.fusion_callbacks.emit() - return callback_name, f""" -using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor< - {OpcodeClassTag[self.opclass]}, - {self.CtaTileMNK}, {self.EpilogueTileType}, - {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, - {self.Schedule}, {stride_C_str}, {stride_D_str}, - false /* IsPerColScaleSupported */, - false /* IsBlockScaleSupported */ ->; -{callback_decl} -""" - - -class Sm100Emitter: - def __init__(self, operation: GemmOperationUniversal, graph) -> None: - fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False) - - self.collective_epilogue = Sm100CollectiveEpilogue( - tile_description=operation.tile_description, - kernel_schedule=operation.tile_description.kernel_schedule, - epilogue_schedule=operation.tile_description.epilogue_schedule, - element_accumulator=operation.tile_description.math_instruction.element_accumulator, - element_d=fusion_callbacks.dag_ir.get_node_meta("D").element, - fusion_callbacks=fusion_callbacks - ) - - def emit(self): - return self.collective_epilogue.emit() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py deleted file mode 100644 index 33e77b4c9f2efbef808f8551e4402f5a6761ea4a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py +++ /dev/null @@ -1,134 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from pycute import product - -from cutlass_library import DataTypeSize, DataTypeTag - -from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl -import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes - -from cutlass_cppgen.backend.library import FloatRoundStyleTag - - -Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl -Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl -Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl -Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl -Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl -Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl -Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl -Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl -Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl -Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl - - -class Sm100AuxLoadImpl(AuxLoadImpl): - - @property - def descriptor(self) -> str: - """ - Descriptor for Aux Load - """ - return f"{self.name_camel}Descriptor" - - def decl_descriptor(self) -> str: - """ - Declare the descriptor type - """ - return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor;\n" - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = self.decl_descriptor() - self._type_decl += f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< - {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, - {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R ->; -""" - return self._type_decl - - def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): - """ - Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d - """ - return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) - - -class Sm100AuxStoreImpl(AuxStoreImpl): - - @property - def descriptor(self) -> str: - """ - Descriptor for Aux Load - """ - return f"{self.name_camel}Descriptor" - - def decl_descriptor(self) -> str: - """ - Declare the descriptor type - """ - return f""" -using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor< - EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} ->; -""" - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = self.decl_descriptor() - self._type_decl += f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< - {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, - {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, - typename {self.descriptor}::CopyOpR2S ->; -""" - return self._type_decl - - def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): - """ - Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d - """ - return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py deleted file mode 100644 index 868453a7cf5049e5899bf6aef419485a1a5dbb43..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py +++ /dev/null @@ -1,47 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Emitter for Sm80 Epilogue Visitor -""" - -from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks -from cutlass_cppgen.backend import GemmOperationUniversal - - -class Sm80Emitter: - def __init__(self, operation: GemmOperationUniversal, graph) -> None: - self.fusion_callbacks = FusionCallbacks(graph, cc=80) - - def emit(self): - callback_decl, callback_name = self.fusion_callbacks.emit() - return callback_name, callback_decl diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py deleted file mode 100644 index b9fc561354a471f4f97600b27e4dbb21950a9e79..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py +++ /dev/null @@ -1,258 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_library import DataTypeSize, DataTypeTag - -from cutlass_cppgen.backend.evt.ir import ( - # Load Node - AccumulatorImpl, - AuxLoadImpl, - ColumnBroadcastImpl, - LoadNode, - LoadSrcImpl, - RowBroadcastImpl, - ScalarBroadcastImpl, - # Compute Node - ComputeImpl, - # Store Node - AuxStoreImpl, - ColumnReductionImpl, - RowReductionImpl, - ScalarReductionImpl -) - -from cutlass_cppgen.backend.library import ( - FloatRoundStyleTag, - FunctionalOp, - op_tag, -) - - -class Sm80AccumulatorImpl(AccumulatorImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n""" - return self._type_decl - - -class Sm80AuxLoadImpl(AuxLoadImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad< - OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm80LoadSrcImpl(Sm80AuxLoadImpl): - pass - - -class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl): - def __init__(self, node: LoadNode) -> None: - super().__init__(node) - self.broadcast_count = 1 - self.reduction_fn = FunctionalOp.Multiplies - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast< - {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)} ->; -""" - return self._type_decl - - -class Sm80RowBroadcastImpl(RowBroadcastImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, {DataTypeTag[self.element]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, {DataTypeTag[self.element]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm80ComputeImpl(ComputeImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute< - {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]}, - {FloatRoundStyleTag[self.round_style]} ->; -""" - return self._type_decl - - -class Sm80AuxStoreImpl(AuxStoreImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm80StoreDImpl(Sm80AuxStoreImpl): - pass - - -class Sm80ColumnReductionImpl(ColumnReductionImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction< - {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, - OutputTileThreadMap, {DataTypeTag[self.element]}, - {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm80RowReductionImpl(RowReductionImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction< - {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, - OutputTileThreadMap, {DataTypeTag[self.element]}, - {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm80ScalarReductionImpl(ScalarReductionImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction< - {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, - OutputTileThreadMap, {DataTypeTag[self.element]}, - {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, - {self.stride_mnl} ->; -""" - return self._type_decl diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py deleted file mode 100644 index 3c058aa8f30a56d97ce3c3600f7c89189e7a15ad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py +++ /dev/null @@ -1,98 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Emitter for Sm90 Epilogue Visitor -""" - -from cutlass_library import DataTypeTag, EpilogueScheduleTag -from cutlass_cppgen.backend import GemmOperationUniversal -from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks - - -class CollectiveEpilogue: - def __init__(self, tile_description, - schedule, - element_c, - element_d, - fusion_callbacks) -> None: - - self.cta_tile_mnk = tile_description.threadblock_shape - self.element_c = element_c - self.element_d = element_d - self.schedule = schedule - self.fusion_callbacks = fusion_callbacks - - @property - def CtaTileMNK(self) -> str: - """ - The threadblock shape - """ - return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" - - @property - def EpilogueTileType(self) -> str: - """ - The epilogue tile type - """ - return "cutlass::epilogue::collective::EpilogueTileAuto" - - @property - def Schedule(self) -> str: - return EpilogueScheduleTag[self.schedule] - - def emit(self): - callback_decl, callback_name = self.fusion_callbacks.emit() - return callback_name, f""" -using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< - {self.CtaTileMNK}, {self.EpilogueTileType}, - {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, - {self.Schedule} ->; -{callback_decl} -""" - - -class Sm90Emitter: - def __init__(self, operation: GemmOperationUniversal, graph) -> None: - fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False) - - self.collective_epilogue = CollectiveEpilogue( - tile_description=operation.tile_description, - schedule=operation.tile_description.epilogue_schedule, - element_c=operation.C.element, - element_d=operation.C.element, - fusion_callbacks=fusion_callbacks - ) - - def emit(self): - return self.collective_epilogue.emit() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py deleted file mode 100644 index 43601a424e3ecb175837fb31389436c1470d9c0b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py +++ /dev/null @@ -1,329 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from pycute import product - -from cutlass_library import DataTypeSize, DataTypeTag -from cutlass_cppgen.backend.evt.ir import ( - # Load Node - AccumulatorImpl, - AuxLoadImpl, - ColumnBroadcastImpl, - LoadNode, - LoadSrcImpl, - RowBroadcastImpl, - ScalarBroadcastImpl, - # Compute Node - ComputeImpl, - ComputeNode, - # Store Node - AuxStoreImpl, - ColumnReductionImpl, - RowReductionImpl, - ScalarReductionImpl, - StoreNode, - StoreDImpl, -) -from cutlass_cppgen.backend.library import ( - FloatRoundStyleTag, - FunctionalOp, - op_tag, -) - - -class Sm90AccumulatorImpl(AccumulatorImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n""" - return self._type_decl - - -class Sm90LoadSrcImpl(LoadSrcImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using ElementC = {DataTypeTag[self.element]}; -using StrideC = {self.stride_mnl}; -using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>; -""" - return self._type_decl - - -class Sm90AuxLoadImpl(AuxLoadImpl): - - @property - def descriptor(self) -> str: - """ - Descriptor for Aux Load - """ - return f"{self.name_camel}Descriptor" - - def decl_descriptor(self) -> str: - """ - Declare the descriptor type - """ - return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor;\n" - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = self.decl_descriptor() - self._type_decl += f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< - {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, - {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R ->; -""" - return self._type_decl - - def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): - """ - Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d - """ - return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) - - -class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl): - def __init__(self, node: LoadNode) -> None: - super().__init__(node) - self.broadcast_count = 1 - self.reduction_fn = FunctionalOp.Multiplies - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast< - {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)} ->; -""" - return self._type_decl - - -class Sm90RowBroadcastImpl(RowBroadcastImpl): - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm90ComputeImpl(ComputeImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute< - {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]}, - {FloatRoundStyleTag[self.round_style]} ->; -""" - return self._type_decl - - -class Sm90AuxStoreImpl(AuxStoreImpl): - - @property - def descriptor(self) -> str: - """ - Descriptor for Aux Load - """ - return f"{self.name_camel}Descriptor" - - def decl_descriptor(self) -> str: - """ - Declare the descriptor type - """ - return f""" -using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor< - EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} ->; -""" - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = self.decl_descriptor() - self._type_decl += f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< - {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, - {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, - typename {self.descriptor}::CopyOpR2S ->; -""" - return self._type_decl - - def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): - """ - Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d - """ - return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) - - -class Sm90StoreDImpl(StoreDImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - return f""" -using ElementD = {DataTypeTag[self.element]}; -using StrideD = {self.stride_mnl}; -""" - - -class Sm90ColumnReductionImpl(ColumnReductionImpl): - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction< - {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0, - typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, - {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm90RowReductionImpl(RowReductionImpl): - - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction< - {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */, - typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, - {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, - {self.stride_mnl} ->; -""" - return self._type_decl - - -class Sm90ScalarReductionImpl(ScalarReductionImpl): - - - @property - def type_decl(self): - """ - Return the string defining the type - """ - if self._type_decl is not None: - return self._type_decl - - self._type_decl = f""" -using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction< - {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, - {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]}, - {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl} ->; -""" - return self._type_decl diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py deleted file mode 100644 index da446e76d9ebd9de04950a89b2451480492147a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py +++ /dev/null @@ -1,168 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Epilogue Visitor interface for compiling, and running visitor-based epilogue. -""" - -import ctypes - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -from cutlass_library import DataType -import numpy as np - -from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase -import cutlass_cppgen.backend.evt.backend -from cutlass_cppgen.backend.frontend import TensorFrontend -from cutlass_cppgen.utils.datatypes import is_numpy_tensor -from cutlass_cppgen.backend.evt.passes.util import cc_map - - -class EpilogueFunctorVisitor(EpilogueFunctorBase): - """ - Apply an epilogue functor described by the epilogue EVT - - :param cc: compute capability - :param visitor_frontend: user-provide visitor frontend - - """ - def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None: - # Type of Emitter based on CC - self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter") - - # Visitor Types - self.visitor = visitor - self.graph = visitor.dag_ir - - # Data types - self.element_epilogue = element_compute # element compute - self.element_output = self.graph.get_node_meta('D').underlying_impl.element - - # Epilogue Thread Type - epilogue_thread_type = self.visitor.epilogue_thread_type - if cc_map[cc] in [90, 100]: - self.arg_c_type = self.visitor.arg_c_type - self.arg_d_type = self.visitor.arg_d_type - output_names = self.visitor.return_names - reduction_names = self.visitor.reduction_names - - # Epilogue stages specialized for sm80 kernel - if cc == 80: - if hasattr(self.visitor, "epilogue_stages"): - self.epilogue_stages = self.visitor.epilogue_stages - assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue" - - # Epilogue Argument Type - class _Arguments(ctypes.Structure): - """ - Concepts: - class _EpilogueArguments(ctypes.Structure): - _fields_ = [ - ("epilogue", _Arguments), <- this class - ("ptr_C", ctypes.c_void_p), - ("stride_C", StrideBatched_), - ("ptr_D", ctypes.c_void_p), - ("stride_D", StrideBatched_) - ] - """ - _fields_ = [ - ("output_op", epilogue_thread_type) - ] - - def __init__(self, kwargs: dict) -> None: - # The user-input kwargs is a dict of (name: tensors) - # We first convert all of them to device pointers - ptr_kwargs = {} - for key in kwargs.keys(): - is_output = key in output_names and key not in reduction_names - ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output) - # Initialize the thread arguments - self.output_op = epilogue_thread_type(ptr_kwargs) - - def get_tensor_ptr(self, tensor_name, kwargs, is_output=False): - """ - Helper function for extracting device pointer - """ - # Skip the special tensors - if cc in [90, 100]: - if tensor_name in ["C", "D"]: - return 0 - if tensor_name not in kwargs.keys(): - raise ValueError(f"Tensor {tensor_name} is not provided.") - tensor = kwargs[tensor_name] - - # For float scalar constant, directly return the value - if isinstance(tensor, float): - return tensor - - # The tensor frontend returns a device buffer for np.ndarray - # and device ptr for other frontends - buffer_or_ptr = TensorFrontend.argument(tensor, is_output) - if is_numpy_tensor(tensor): - # Remember the host tensor for later synchronization - setattr(self, f"{tensor_name}_buffer", buffer_or_ptr) - setattr(self, f"{tensor_name}_host", tensor) - return int(buffer_or_ptr.ptr) - else: - return int(buffer_or_ptr) - - def sync(self): - """ - Synchronize the results from device to host - """ - for name in output_names: - if hasattr(self, f"{name}_host"): - host_tensor = getattr(self, f"{name}_host") - tensor_ptr = getattr(self, f"{name}_buffer").ptr - (err,) = cuda.cuMemcpyDtoH( - host_tensor, - tensor_ptr, - host_tensor.size * host_tensor.itemsize, - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - self.epilogue_type = _Arguments - - def emit(self, operation): - """ - Emit the C++ code - """ - emitter = self.emit_cls(operation, self.graph) - return emitter.emit() - - def get_smem_size(self, tile_description): - """ - Get the shared memory size in bytes - """ - return self.visitor.get_smem_size(tile_description) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py deleted file mode 100644 index f2323278ed232adea205e41b901c62a268e56976..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py deleted file mode 100644 index 213aafdbe3f922f22186e37ac9f2eefea74e71ce..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py +++ /dev/null @@ -1,272 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Base class for Python EVT Frontend -""" - -from typing import Union - -from cutlass_library import DataType -from cutlass_cppgen.backend.evt.ir import ( - ComputeNode, - DAGIR, - LayoutNode, - LoadNode, - StoreNode, -) -from cutlass_cppgen.backend.evt.passes import ( - EVTGraphDrawer, - EVTPassManager, - GetSmemSize, - PassDAG2Tree, - PassGetArgumentType, - PassGetImpl, - PassFixElementD, - PassLayoutManipulateElimination, - PassPreprocessRed, - PassShapeTypePropagation, -) -from cutlass_cppgen.backend.evt.passes.util import cc_map -from cutlass_cppgen.backend.utils import device_cc -from cutlass_cppgen.epilogue.evt_ops import permute, reshape -from cutlass_cppgen.utils.datatypes import library_type - - -class EVTFrontendBase: - layout_fns = { - "permute": permute, - "reshape": reshape - } - - def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None: - self.cc = cc - self.element_compute = library_type(element_compute) - self.dag_ir = DAGIR(self.cc, self.element_compute) - self.compute_cnt = 0 - self.layout_cnt = 0 - self.imm_cnt = 0 - - self.pass_manager = EVTPassManager( - self.dag_ir, - [ - PassPreprocessRed, - PassGetArgumentType, - PassShapeTypePropagation, - PassLayoutManipulateElimination, - PassGetImpl, - PassDAG2Tree, - PassFixElementD - ] + additional_passes) - - if self.cc == 80: - self._epilogue_stages = 1 - else: - self._epilogue_stages = None - - @property - def epilogue_stages(self): - return self._epilogue_stages - - @epilogue_stages.setter - def epilogue_stages(self, stages): - self._epilogue_stages = stages - - - def parse(self, *args, **kwargs): - raise NotImplementedError(f"The 'parse' function must be overloaded in frontend class") - - def trace(self, *args, **kwargs): - # Parse the input - self.parse(*args, **kwargs) - - # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0 - if (self.cc >= 90): - if (self.dag_ir.out_degree("D") != 0): - raise RuntimeError( - f"On SM90 or higher, D is expected to be a output node with 0 users to " - f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}") - - # Run the passes - self.pass_manager() - # Set the epilogue type - self.epilogue_thread_type = self.dag_ir.epilogue_thread_type - if cc_map[self.cc] in [90, 100]: - self.arg_c_type = self.dag_ir.arg_c_type - self.arg_d_type = self.dag_ir.arg_d_type - self.reduction_names = self.dag_ir.reduction_names - - # - # Helper functions for DAG IR manipulation - # - - def add_node(self, node): - self.dag_ir.add_node(node) - - def add_edge(self, src, tgt, weight=0): - self.dag_ir.add_edge(src, tgt, weight=weight) - - def set_tensor(self, node_name, example): - """ - Add an example tensor to node {node_name} in the DAG IR - """ - meta = self.dag_ir.get_node_meta(node_name) - meta.tensor = {"tensor": example} - - def set_store_tensor(self, node_name, example): - """ - Add an example tensor to node {node_name} in the DAG IR - """ - meta = self.dag_ir.get_node_meta(node_name) - meta.store_tensor = {"tensor": example} - - def mark_output(self, node_name): - """ - Mark a store node as output - """ - meta = self.dag_ir.get_node_meta(node_name) - if not isinstance(meta, StoreNode): - raise ValueError( - f"Only StoreNodes can be marked as output. " - f"Got {type(meta).__name__}: {node_name}") - meta.is_output = True - - # Add node with specific type - - def add_load_node(self, name, example): - """ - Add a Load node to DAG IR - :param name: name of the loaded variable - :type name: str - :param example: example input - :type example: np.ndarray|torch.Tensor|cupy.ndarray|float - """ - if name is None: - raise ValueError(f"Name is not provided.") - if example is None: - raise ValueError(f"Example input for {name} is not provided.") - load_node = LoadNode(name) - load_node.tensor = {"tensor": example} - # Special logics for accumulator - if name == "accum": - if load_node.tensor.rank == 2: - new_shape = tuple([1, ] + list(load_node.tensor.shape)) - load_node.tensor.broadcast(new_shape) - elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3: - raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.") - self.add_node(load_node) - - def add_imm(self, value: Union[float,int]): - """ - Add an immediate scalar value to DAG IR - :param value: the value of the immediate scalar - :type value: float - """ - try: - value = float(value) - except: - raise ValueError(f"{type(value).__name__} cannot be converted to float.") - - name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_') - self.imm_cnt += 1 - load_node = LoadNode(name) - load_node.tensor = {"tensor": value, "is_constant": True} - self.add_node(load_node) - return name - - def add_compute_node(self, op, name=None): - """ - Add a compute node. - :param op: the computation op - :param name: the node name (optional) - :type name: str - :return: the name of the compute node - """ - if name is None: - name = f"compute_{self.compute_cnt}" - self.compute_cnt += 1 - compute_node = ComputeNode( - name=name, fn=op, - element_output=self.element_compute, - element_compute=self.element_compute) - self.add_node(compute_node) - return compute_node.name - - def add_layout_node(self, op, kwargs, name=None): - """ - Add a layout node. - :param op: the layout op - :type op: evt_ops - :param name: the node name (optional) - :type name: str - :return: the name of the layout node - """ - if name is None: - name = f"layout_{self.layout_cnt}" - self.layout_cnt += 1 - layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs) - self.add_node(layout_node) - return layout_node.name - - def add_store_node(self, name): - store_node = StoreNode(name) - self.add_node(store_node) - - # - # Visualization The DAG IR - # - - def visualize(self, name="dag_ir"): - """ - Visualize the dag ir with svg file - :param name: the name of the graph - """ - drawer = EVTGraphDrawer(self.dag_ir, name) - try: - for name, graph in drawer.get_dot_graph(): - graph.write_svg(f"./{name}.svg") - except: - raise RuntimeError( - "'dot' is not found in path. GraphDrawer is disabled. " - "Please install it with 'sudo apt-get install graphviz'." - ) - - # - # Get shared memory size - # - - def get_smem_size(self, tile_description): - """ - Get the shared memory size of the epilogue - """ - smem_size = GetSmemSize(self.dag_ir)(tile_description) - return smem_size diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py deleted file mode 100644 index 8727b754cd2b9a557d45760cb0a24a43619a373f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py +++ /dev/null @@ -1,194 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Python AST frontend that parses input into DAG IR -""" - -import ast -import inspect -import textwrap - -from cutlass_library import DataType - -import cutlass_cppgen -from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase -from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu -from cutlass_cppgen.backend.library import FunctionalOp - - -class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor): - def __init__(self, cc, element_compute=DataType.f32, **kwargs): - super().__init__(cc, element_compute, **kwargs) - # Flags - # If this state is True, visit_Constant returns values without creating imm node - self.no_imm = False - self.visiting_return = False - - def parse(self, example_inputs): - self.example_inputs = example_inputs - self.source = textwrap.dedent(inspect.getsource(self.__call__)) - self.ast = ast.parse(self.source) - self.visit(self.ast) - - # - # Helper functions - # - @staticmethod - def ast_op_to_bindings(op): - mapping = { - ast.Add: FunctionalOp.Plus, - ast.Sub: FunctionalOp.Minus, - ast.Mult: FunctionalOp.Multiplies, - ast.Div: FunctionalOp.Divides, - "maximum": FunctionalOp.Maximum, - "minimum": FunctionalOp.Minimum, - "identity": identity.binding_type, - "relu": relu.binding_type, - "tanh": tanh.binding_type, - "sigmoid": sigmoid.binding_type, - "silu": silu.binding_type, - "hardswish": hardswish.binding_type, - "gelu": gelu.binding_type, - "multiply_add": FunctionalOp.MultiplyAdd, - "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd), - "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum), - "exp": FunctionalOp.Exp - } - return mapping[op] - - # - # Visiting different node types - # - - def visit_FunctionDef(self, node: ast.FunctionDef): - # Visit args and register load nodes - for arg in node.args.args: - self.visit(arg) - for expr in node.body: - self.visit(expr) - - def visit_arg(self, node: ast.arg): - # Name of the argument - name = node.arg - try: - example_tensor = self.example_inputs[name] - except: - raise RuntimeError(f"Example input for {name} is not provided.") - - self.add_load_node(name, example_tensor) - - def visit_Name(self, node: ast.Name): - return node.id - - def visit_Constant(self, node: ast.Constant): - if self.no_imm: - return node.value - else: - name = self.add_imm(node.value) - return name - - def visit_Tuple(self, node: ast.Tuple): - results = [] - for elt in node.elts: - results.append(self.visit(elt)) - return tuple(results) - - def visit_keyword(self, node: ast.keyword): - return {node.arg: self.visit(node.value)} - - def visit_BinOp(self, node: ast.BinOp): - if self.visiting_return: - raise SyntaxError("Return value cannot be an expression") - lhs = self.visit(node.left) - rhs = self.visit(node.right) - op = self.ast_op_to_bindings(type(node.op)) - name = self.add_compute_node(op) - - # Add edges - # The edge weights are used to sort the input args - self.add_edge(lhs, name, weight=0) - self.add_edge(rhs, name, weight=1) - return name - - def visit_Assign(self, node: ast.BinOp): - target = self.visit(node.targets[0]) - value = self.visit(node.value) - # Create the assign node - self.add_store_node(target) - - # Add edges - self.add_edge(value, target) - return target - - def visit_Call(self, node: ast.Call): - if self.visiting_return: - raise SyntaxError("Return value cannot be an expression") - func = self.visit(node.func) - args = [self.visit(arg) for arg in node.args] - - if func in self.layout_fns.keys(): - # Parse kwargs - # By default, visiting imm automatically creates a load node - # However, in function call, keyword args are used to set - # specific function attributes such as indices for permute - # So no_imm is set to True temporarily - self.no_imm = True - kwargs = {} - for kw in node.keywords: - kwargs.update(self.visit(kw)) - self.no_imm = False - op = self.layout_fns[func] - name = self.add_layout_node(op, kwargs) - else: - op = self.ast_op_to_bindings(func) - name = self.add_compute_node(op) - - # Add edges - for idx, arg in enumerate(args): - self.add_edge(arg, name, weight=idx) - return name - - def visit_Return(self, node: ast.Return): - self.visiting_return = True - results = self.visit(node.value) - self.visiting_return = False - self.return_names = results - if not isinstance(results, tuple): - results = (results,) - for rst in results: - try: - example_tensor = self.example_inputs[rst] - except: - raise RuntimeError(f"Example input for {rst} is not provided.") - self.set_store_tensor(rst, example_tensor) - self.mark_output(rst) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py deleted file mode 100644 index 0f9e3f811a020164dc5ec5eb4a8dfaf3dc5728fe..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl -from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR -from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode -from cutlass_cppgen.backend.evt.ir.load_nodes import ( - LoadNode, - AccumulatorImpl, - LoadSrcImpl, - AuxLoadImpl, - RowBroadcastImpl, - ColumnBroadcastImpl, - ScalarBroadcastImpl -) -from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl -from cutlass_cppgen.backend.evt.ir.store_nodes import ( - StoreNode, - StoreDImpl, - AuxStoreImpl, - ColumnReductionImpl, - RowReductionImpl, - ScalarReductionImpl -) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py deleted file mode 100644 index 02b05358648694dcf2a5afd7117e6fca6a2d136c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py +++ /dev/null @@ -1,91 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Python registration for compute nodes in EVT -""" - -from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase -from cutlass_cppgen.backend.library import FloatRoundStyle - - -class ComputeImplBase(ImplBase): - """ - Base class for compute implementation - """ - def __init__(self, node) -> None: - super().__init__(node) - - -class ComputeImpl(ComputeImplBase): - """ - Implementation for Compute Node - """ - def __init__(self, node) -> None: - super().__init__(node) - - self.fn = node.fn - self.element_output = node.element_output - self.element_compute = node.element_compute - self.round_style = node.round_style - - @staticmethod - def match(node, problem_size: tuple): - return True - - -class ComputeNode(NodeBase): - """ - Compute Node in DAG IR - """ - possible_impls = [ - ComputeImpl - ] - def __init__( - self, name: str, fn, element_output, - element_compute, - round_style=FloatRoundStyle.ToNearest) -> None: - super().__init__(name) - self.op = "compute" - self.fn = fn - self.element_compute = element_compute - self.round_style = round_style - - def type_propagation(self, *args, **kwargs): - """ - Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`. - """ - self.element = self.element_compute - # In general, the compute nodes have element_output = element_compute - # In certain cases like producer of D it is overwritten by other passes - if not hasattr(self, "element_output"): - self.element_output = self.element diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py deleted file mode 100644 index e7e9f75a9727306d56c049bd491a95542a68bec8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py +++ /dev/null @@ -1,254 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -DAG IR used by Python EVT -""" - -import networkx as nx - -from cutlass_library import DataType - -from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode -from cutlass_cppgen.backend.evt.ir.node import NodeBase -from cutlass_cppgen.backend.library import ActivationOp -from cutlass_cppgen.backend.utils import device_cc - - -class DAGIR: - """ - ``DAGIR`` is the main data structure used in the EVT Intermediate Representation. - It consists of a series of ``Node`` s, each representing epilogue visitor nodes. - - In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node - """ - def __init__(self, cc, element_compute=DataType.f32) -> None: - # The EVT DAGIR is managed through the nextworkX Digraph class - self._graph = nx.DiGraph() - - self.element_compute = element_compute - - self.reduction_names = [] - - self.cc = cc - - self.identity_counter = 0 - - # - # IR manipulator - # - - def add_node(self, meta: NodeBase): - """ - Add a node to dag ir - """ - if self.has_node(meta.name): - raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.") - self._graph.add_node(meta.name, meta=meta) - - def add_edge(self, src: str, dst: str, weight: int=0): - """ - Add an edge src -> dst to dag ir with weight - """ - if not self.has_node(src): - raise SyntaxError(f"Variable '{src}' is undefined.") - if not self.has_node(dst): - raise SyntaxError(f"Variable '{dst}' is undefined.") - - if self._graph.has_edge(src, dst): - # The DiGraph doesn't support multiple edges between two nodes - # We insert an identity node in such case as a workaround - identity_name = f"autogen_identity_{self.identity_counter}" - self.identity_counter += 1 - compute_node = ComputeNode( - name=identity_name, fn=ActivationOp.Identity, - element_output=self.element_compute, - element_compute=self.element_compute) - self.add_node(compute_node) - self.add_edge(src, identity_name, 0) - self.add_edge(identity_name, dst, weight) - else: - self._graph.add_edge(src, dst, weight=weight) - - def remove_node(self, node: str): - """ - Remove node from dag ir - """ - self._graph.remove_node(node) - - def remove_edge(self, src: str, dst: str): - """ - Remove edge src -> dst - """ - self._graph.remove_edge(src, dst) - - # - # Helper functions for getting attrs - # - - def has_node(self, node: str) -> bool: - """ - Check if the node is in the graph - """ - return self._graph.has_node(node) - - def in_degree(self, node: str): - """ - Get the input degree of node - """ - return self._graph.in_degree(node) - - def in_edges(self, node: str): - """ - Get the input edges of node - """ - return [edge for edge in self._graph.in_edges(node)] - - def out_degree(self, node: str): - """ - Get the output degree of node - """ - return self._graph.out_degree(node) - - def out_edges(self, node: str): - """ - Get the output edges of node - """ - return [edge for edge in self._graph.out_edges(node)] - - def get_node_meta(self, node: str): - """ - Get the meta data of the node - """ - return self._graph.nodes[node]["meta"] - - def get_edge_weight(self, src, dst): - """ - Get the edge weight of edge src->dst - """ - return self._graph.get_edge_data(src, dst)["weight"] - - # - # High-level helper functions - # - - def all_reachable_nodes(self, node: str): - """ - Get all the nodes reachable from the current node (exclude) - """ - return list(nx.dfs_preorder_nodes(self._graph, source=node)) - - def get_users(self, node: str): - """ - Get all users of the current node - """ - return [edge[1] for edge in self.out_edges(node)] - - def get_all_inputs(self, node: str): - """ - Get all the input nodes sorted by edge weight - """ - in_edges = self.in_edges(node) - edge_weights = [self.get_edge_weight(*edge) for edge in in_edges] - return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))] - - def get_all_inputs_meta(self, node: str): - """ - Get all the input node metas sorted by edge weight - """ - return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)] - - def replace_all_uses_with(self, node1, node2): - """ - Replace all uses of node1 with node2 - """ - for edge in self.out_edges(node1): - weight = self.get_edge_weight(*edge) - user = edge[1] - self.add_edge(node2, user, weight) - self.remove_edge(node1, user) - self.remove_node(node1) - - # - # Node accessor - # - def nodes_topological_order(self): - """ - Get the nodes in the unique lexicographical topological order - It generates a unique ordering of nodes by first sorting topologically - and then additionally by sorting lexicographically. - - Although topological_sort alone also works, this generates a unique key - for each epilogue visitor pattern and ensures the compilation cache can be reused. - :return: list[str] - """ - return list(nx.lexicographical_topological_sort(self._graph)) - - def node_metas_topological_order(self): - """ - Get the node metas in topological order - :return: list[NodeBase] - """ - return [self.get_node_meta(node) for node in self.nodes_topological_order()] - - @property - def nodes(self): - """ - Get all nodes - :return: list[str] - """ - return list(self._graph.nodes) - - @property - def nodes_meta(self): - """ - Get all node metas - :return: list[NodeBase] - """ - return [data[1]['meta'] for data in self._graph.nodes.data()] - - @property - def edges(self): - """ - Get all edges - :return: list[(str, str)] - """ - return list(self._graph.edges) - - # - # Path - # - def has_path(self, src: str, target: str) -> bool: - """ - Return True is a path exists from src to target - """ - return nx.has_path(self._graph, src, target) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py deleted file mode 100644 index 9d453b1f4c41d002297c5348cbed8fd7f0ef3081..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py +++ /dev/null @@ -1,324 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Layout algebras -""" - -from pycute import Layout, composition, make_layout, flatten, product - - -def _infer_split(old_shape, new_shape): - old_shape = _tuple_to_list(old_shape) - new_shape = _tuple_to_list(new_shape) - if len(old_shape) == 0 and len(new_shape) == 0: - return [] - if len(old_shape) == 0: - if product(tuple(new_shape)) != 1: - raise ValueError("Invalid reshape size") - else: - return new_shape - if len(new_shape) == 0: - if product(tuple(old_shape)) != 1: - raise ValueError("Invalid reshape size") - else: - return old_shape - # This is done recursively by only process the last dimension at each time - old_dim = old_shape[-1] - new_dim = new_shape[-1] - # Exact match - if old_dim == new_dim: - return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,] - # Needs split - if old_dim > new_dim and old_dim % new_dim == 0: - residual = old_dim // new_dim - return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,] - # Needs merge - if old_dim < new_dim and new_dim % old_dim == 0: - residual = new_dim // old_dim - return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,] - - raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}") - -def _infer_merge(flatten_shape, shape): - flatten_shape = _tuple_to_list(flatten_shape) - shape = _tuple_to_list(shape) - idx_flat = 0 - merged_shape = [] - for dim in shape: - # Exact match - if dim == flatten_shape[idx_flat]: - merged_shape.append(dim) - idx_flat += 1 - # Need group - elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: - residual = dim - group = [] - while(residual > 1): - group.append(flatten_shape[idx_flat]) - residual = residual // flatten_shape[idx_flat] - idx_flat += 1 - merged_shape.append(group) - else: - raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") - - return merged_shape - -def _list_to_tuple(nested_list): - if isinstance(nested_list, list) or isinstance(nested_list, tuple): - return tuple(_list_to_tuple(item) for item in nested_list) - return nested_list - -def _tuple_to_list(nested_tuple): - if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple): - return list(_tuple_to_list(item) for item in nested_tuple) - return nested_tuple - -def _reverse_tuple(nested_tuple: tuple): - if isinstance(nested_tuple, tuple): - return tuple([_reverse_tuple(item) for item in nested_tuple][::-1]) - return nested_tuple - -def _get_first_lhs_nonzero_stride(stride_list, idx): - for i in reversed(range(idx)): - if stride_list[i] != 0: - return i - else: - return None - -def _get_first_rhs_nonzero_stride(stride_list, idx): - for i in range(idx+1, len(stride_list)): - if stride_list[i] != 0: - return i - else: - return None - -def reshape(layout, new_shape): - """ - General reshape of input layout. - It takes two steps: - 1. split the dimensions of the old layout - 2. merge the splitted dimensions according to the new shape - """ - # - # Step 1: Split the dimensions of the old layout - # - # 1.1 Flat old and new shape - old_flatten_shape = list(flatten(layout.shape)) - new_flatten_shape = list(flatten(new_shape)) - - # 1.2 Infer the flatten splitted shape - splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape) - - # 1.3 Unflat the splitted shape based on the old shape - splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape) - - # 1.4 Infer the type of each split - # If the split type is in row-major (R), the dimension list is reversed because - # the cute::composition only support column-major split - split_type = [] # the type of each split (ColumnMajor or RowMajor) - permuted_splitted_shape = [] - old_flatten_stride = list(flatten(layout.stride)) - for idx, dim in enumerate(splited_shape): - if not isinstance(dim, list): - permuted_splitted_shape.append(dim) - split_type.append("C") - else: - lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx) - rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx) - # Special case for single tuple - # Use column-major by default - if lhs_stride is None and rhs_stride is None: - permuted_splitted_shape.append(dim) - split_type.append("C") - else: - if lhs_stride is not None and rhs_stride is not None: - # We consider shape[idx]:stride[idx] - # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major - if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: - permuted_splitted_shape.append(dim) - split_type.append("C") - # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major - elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: - permuted_splitted_shape.append([d for d in reversed(dim)]) - split_type.append("R") - # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave - elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: - if lhs_stride >= rhs_stride: - permuted_splitted_shape.append(dim) - split_type.append("C") - else: - permuted_splitted_shape.append([d for d in reversed(dim)]) - split_type.append("R") - # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave - elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: - if lhs_stride >= rhs_stride: - permuted_splitted_shape.append(dim) - split_type.append("C") - else: - permuted_splitted_shape.append([d for d in reversed(dim)]) - split_type.append("R") - else: - raise NotImplementedError() - elif lhs_stride is None: - # Case 1: dim's stride < dim+1's stride, expand in column major - if old_flatten_stride[idx] > rhs_stride: - permuted_splitted_shape.append([d for d in reversed(dim)]) - split_type.append("R") - else: - permuted_splitted_shape.append(dim) - split_type.append("C") - else: - # Case 1: dim's stride > dim-1's stride - if old_flatten_stride[idx] < lhs_stride: - permuted_splitted_shape.append([d for d in reversed(dim)]) - split_type.append("R") - else: - permuted_splitted_shape.append(dim) - split_type.append("C") - - # 1.4 Generate the splitted layout - permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape))) - - # 1.5 Reverse the permutation in 1.4 before merge - splitted_shape = [] - splitted_stride = [] - for shape_dim, stride_dim, type in zip( - permuted_splitted_layout.shape, - permuted_splitted_layout.stride, - split_type): - if type == "C": - splitted_shape.append(shape_dim) - splitted_stride.append(stride_dim) - else: - splitted_shape.append(tuple([d for d in reversed(shape_dim)])) - splitted_stride.append(tuple([d for d in reversed(stride_dim)])) - splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride)) - - - # - # Step 2: Merge the splitted dimensions according to the new shape - # - # 2.1 Merge layout - merged_layout = composition(splitted_layout, Layout(new_shape)) - - # 2.2 Cleaning up - output_layout = composition(merged_layout, Layout(new_shape)) - return output_layout - - -def permutation(layout, permutation): - """ - Permute the layout - """ - new_shape = tuple([layout.shape[idx] for idx in permutation]) - new_stride = tuple([layout.stride[idx] for idx in permutation]) - return Layout(new_shape, new_stride) - - -def _broadcast(layout, new_shape): - if len(layout) == 1 and isinstance(new_shape, int): - old_dim = layout.shape - old_stride = layout.stride - new_dim = new_shape - if old_dim == new_dim: - return Layout(old_dim, old_stride) - elif old_dim == 1: - return Layout(new_dim, 0) - else: - raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}") - - # Align the dimensions - old_shape = layout.shape - if isinstance(old_shape, int): - old_shape = (old_shape,) - sub_layouts = [layout,] - else: - sub_layouts = [sub_layout for sub_layout in layout] - rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape)) - # Get the broadcasted layout - broadcast_layouts = [] - try: - layout = make_layout(*sub_layouts, *rhs_broadcast_layouts) - broadcast_layouts = [] - for idx, sub_layout in enumerate(layout): - broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) - except NotImplementedError: - layout = make_layout(*rhs_broadcast_layouts, *sub_layouts) - for idx, sub_layout in enumerate(layout): - broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) - return make_layout(*broadcast_layouts) - - -def broadcast(layout, new_shape): - """ - Broadcast the new layout based on the input shape - The broadcasted shape equals to the new shape - The stride of broadcasted dimensions are 0 - """ - return _broadcast(layout, new_shape) - - -def debroadcast(layout, dims): - """ - Squeeze the 0-stride - """ - for dim in dims: - if layout.stride[dim] != 0: - raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}") - new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims]) - new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims]) - return Layout(new_shape, new_stride) - - -def canonicalization_(shapes, strides): - if isinstance(shapes, tuple): - c_shapes = [] - c_strides = [] - for shape, stride in zip(shapes, strides): - c_shape, c_stride = canonicalization_(shape, stride) - c_shapes.append(c_shape) - c_strides.append(c_stride) - return tuple(c_shapes), tuple(c_strides) - else: - if shapes == 1: - return 1, 0 - else: - return shapes, strides - -def canonicalization(layout): - """ - Canonicalize the input layout - 1. set the stride of shape "1" to 0 - """ - new_shape, new_stride = canonicalization_(layout.shape, layout.stride) - return Layout(new_shape, new_stride) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py deleted file mode 100644 index 1095e2ab1d956399b5e27ddaf140e53d9918ec26..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py +++ /dev/null @@ -1,336 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Layout manipulation nodes and implementations - -The layout Nodes change the layout of intermediate nodes in epilogue visitor graph -""" - -from copy import deepcopy - -from cutlass_library import LayoutType -from pycute import product, flatten - -import cutlass_cppgen -from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list -from cutlass_cppgen.backend.evt.ir.node import NodeBase -from cutlass_cppgen.backend.evt.ir.tensor import Tensor - - -class PermutationImpl: - """ - Detailed implementation and helper functions for permutation - """ - def __init__(self, node) -> None: - assert "indices" in node.kwargs.keys() - self.indices = list(node.kwargs["indices"]) - self.inverse_indices = self.get_inverse_indices(self.indices) - - def get_inverse_impl(self): - inverse_impl = deepcopy(self) - inverse_impl.indices = self.inverse_indices - inverse_impl.inverse_indices = self.indices - return inverse_impl - - def update(self, shape): - num_dim = len(shape) - indices = self.indices - num_old_dim = len(indices) - # Add offset - for i, idx in enumerate(indices): - indices[i] = idx + num_dim - num_old_dim - # Add broadcast dims - for i in range(num_dim - num_old_dim): - indices = [i,] + indices - - self.indices = indices - self.inverse_indices = self.get_inverse_indices(self.indices) - - def get_inverse_indices(self, indices): - """ - Get the indices for inverse permutation - """ - num_dim = len(indices) - inverse_indices = [0] * num_dim - for i in range(num_dim): - inverse_indices[indices[i]] = i - return inverse_indices - - def shape_propagation(self, input_node_meta): - input_shape = input_node_meta.tensor.shape - output_shape = tuple([input_shape[idx] for idx in self.indices]) - return output_shape - - def broadcast(self, shape, node_meta: NodeBase): - """ - Broadcast the inputs based on current shape - """ - self.update(shape) - inverse_shape = tuple([shape[idx] for idx in self.inverse_indices]) - node_meta.tensor.broadcast(inverse_shape) - - def apply_to_user(self, usr_meta: NodeBase): - """ - Propagate the permutation to the users of the current nodes - """ - usr_meta.tensor.permute(self.inverse_indices) - if hasattr(usr_meta, "store_tensor"): - if usr_meta.store_tensor is not None: - usr_meta.store_tensor.permute(self.inverse_indices) - - def apply_to_input(self, input_meta: NodeBase): - """ - Propagate the permutation to inputs of the current nodes - """ - input_meta.tensor.permute(self.indices) - if hasattr(input_meta, "store_tensor"): - if input_meta.store_tensor is not None: - input_meta.store_tensor.permute(self.indices) - - -class ReshapeImpl: - """ - Detailed implementation and helper functions for reshape - """ - def __init__(self, node) -> None: - self.node = node - assert "new_shape" in node.kwargs.keys() - self.output_shape = _list_to_tuple(node.kwargs["new_shape"]) - - def get_inverse_impl(self): - inverse_impl = deepcopy(self) - inverse_impl.output_shape = self.input_shape - inverse_impl.input_shape = self.output_shape - return inverse_impl - - def shape_propagation(self, input_node_meta): - self.input_shape = input_node_meta.tensor.shape - return _list_to_tuple(self.output_shape) - - def broadcast(self, shape, node_meta: NodeBase): - """ - Broadcast the inputs based on current shape. - """ - # Step 1: infer split - flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape)) - split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape) - split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape) - - # broadcast shape -> split_output_shape -> flatten_split_shape - if len(shape) - len(split_output_shape) > 0: - for _ in range(len(shape) - len(split_output_shape)): - split_output_shape = [1,] + split_output_shape - flatten_split_shape = [1,] + flatten_split_shape - split_input_shape = [1,] + split_input_shape - broadcast_factor = [] - for dim, old_dim in zip(shape, split_output_shape): - if not isinstance(dim, list): - dim = [dim,] - if not isinstance(old_dim, list): - old_dim = [old_dim,] - if product(tuple(dim)) == product(tuple(old_dim)): - broadcast_factor += [1] * len(old_dim) - elif product(tuple(old_dim)) == 1: - assert len(dim) == 1 - broadcast_factor.append(dim[0]) - else: - raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}") - - # flatten_split_shape -> split_input_shape - factor_idx = 0 - broadcast_split_input_shape = [] - for dim in split_input_shape: - if isinstance(dim, list): - new_dim = [] - for d in dim: - new_dim.append(d * broadcast_factor[factor_idx]) - factor_idx += 1 - broadcast_split_input_shape.append(new_dim) - else: - broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx]) - factor_idx += 1 - broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape) - node_meta.tensor.reshape(_list_to_tuple(split_input_shape)) - node_meta.tensor.broadcast(broadcast_split_input_shape) - # Last reshape op to clean up - broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape]) - node_meta.tensor.reshape(broadcast_input_shape) - # Update the input shape and output shape - self.input_shape = _list_to_tuple(node_meta.tensor.shape) - self.output_shape = _list_to_tuple(shape) - - def apply_to_user(self, user_meta: NodeBase): - """ - Propagate the reshape to user nodes - """ - user_meta.tensor.reshape(tuple(self.input_shape)) - if hasattr(user_meta, "store_tensor"): - if user_meta.store_tensor is not None: - user_meta.store_tensor.reshape(tuple(self.input_shape)) - - def apply_to_input(self, input_meta: NodeBase): - """ - Propagate the reshape to input nodes - """ - input_meta.tensor.reshape(tuple(self.output_shape)) - if hasattr(input_meta, "store_tensor"): - if input_meta.store_tensor is not None: - input_meta.store_tensor.reshape(tuple(self.output_shape)) - - # - # Helper functions - # - - def infer_split(self, input_shape, output_shape): - """ - Infer the flatten splitted shape that can be merged to both input_shape and output_shape - """ - input_shape = _tuple_to_list(input_shape) - output_shape = _tuple_to_list(output_shape) - if len(input_shape) == 0 and len(output_shape) == 0: - return [] - if len(input_shape) == 0: - if product(tuple(output_shape)) != 1: - raise ValueError("Invalid reshape size") - else: - return output_shape - if len(output_shape) == 0: - if product(tuple(input_shape)) != 1: - raise ValueError("Invalid reshape size") - else: - return input_shape - # This is done recursively by only process the last dimension at each time - old_dim = input_shape[-1] - new_dim = output_shape[-1] - # Exact match - if old_dim == new_dim: - return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,] - # Needs split - if old_dim > new_dim and old_dim % new_dim == 0: - residual = old_dim // new_dim - return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,] - # Needs merge - if old_dim < new_dim and new_dim % old_dim == 0: - residual = new_dim // old_dim - return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,] - - raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}") - - def infer_merge(self, flatten_shape, shape): - flatten_shape = _tuple_to_list(flatten_shape) - shape = _tuple_to_list(shape) - idx_flat = len(flatten_shape) - 1 - merged_shape = [] - for dim in reversed(shape): - # Exact match - if dim == flatten_shape[idx_flat]: - merged_shape.append(dim) - idx_flat -= 1 - # need group - elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: - residual = dim - group = [] - while(residual > 1): - group.append(flatten_shape[idx_flat]) - residual = residual // flatten_shape[idx_flat] - idx_flat -= 1 - merged_shape.append(group[::-1]) - else: - raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") - - return merged_shape[::-1] - - -class LayoutNode(NodeBase): - """ - Layout manipulation nodes - """ - fn_to_impl = { - "permute": PermutationImpl, - "reshape": ReshapeImpl - } - def __init__(self, name: str, fn, kwargs: dict) -> None: - super().__init__(name) - self.op = "layout" - self.fn = fn - self.kwargs = kwargs - self.underlying_impl = self.fn_to_impl[self.fn.__name__](self) - - def get_inverse_node(self): - inverse_node = deepcopy(self) - inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl() - return inverse_node - - def shape_propagation(self, input_node_metas): - if self._tensor is not None: - return - assert len(input_node_metas) == 1, "Layout node can only have one input node" - - output_shape = self.underlying_impl.shape_propagation(input_node_metas[0]) - - self._tensor = Tensor( - element=self.element_output, - shape=output_shape, layout_tag=LayoutType.RowMajor - ) - - return super().shape_propagation(input_node_metas) - - def type_propagation(self, input_node_metas: 'list[NodeBase]'): - """ - The store nodes has element_output = element_input - """ - assert len(input_node_metas) == 1, "Layout node can only have one input node" - self.element_output = input_node_metas[0].element_output - - def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): - """ - Propagate the broadcast in the reversed topological order - """ - if self.tensor is None: - raise RuntimeError(f"The tensor of node {self.name} is unknown.") - shape = self.tensor.shape - - for child in input_node_metas: - self.underlying_impl.broadcast(shape, child) - - def apply_to_user(self, usr_meta: NodeBase): - """ - Propagate the permutation to user nodes - """ - self.underlying_impl.apply_to_user(usr_meta) - - def apply_to_input(self, input_meta: NodeBase): - """ - Propagate the permutation to input nodes - """ - self.underlying_impl.apply_to_input(input_meta) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py deleted file mode 100644 index bff0aaa2c21ef2545f50745cdb33499270eeb9fb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py +++ /dev/null @@ -1,294 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Load nodes and implementations -""" - -import ctypes - -from cutlass_cppgen.backend.c_types import tuple_factory -from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value -from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase - - -class LoadImplBase(ImplBase): - """ - Base class for load node implementations - """ - reserved_names = ["accum", "C"] - def __init__(self, node) -> None: - super().__init__(node) - self.element = node.element - self.element_output = node.element_output - self.stride = node.tensor.stride - - -class AccumulatorImpl(LoadImplBase): - """ - Accumulator node implementation - """ - - @staticmethod - def match(node, problem_size: tuple): - return node.name == "accum" and node.tensor.shape == problem_size - - -class LoadSrcImpl(LoadImplBase): - """ - Load C implementation - """ - @property - def name_camel(self) -> str: - return "TensorC" - - @property - def argument_type_c(self): - stride_mnl = self.get_stride_mnl() - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr_C", ctypes.c_void_p), - ("stride_C", tuple_type) - ] - def __init__(self, ptr) -> None: - self.ptr_C = ptr - self.stride_C = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - return node.name == "C" and node.tensor.shape == problem_size - - -class AuxLoadImpl(LoadImplBase): - """ - Load arbitrary tensor - """ - @property - def argument_type(self): - stride_mnl = self.get_stride_mnl() - name = self.name - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - element_type = self.element - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr_aux", ctypes.c_void_p), - ("null_default", dtype2ctype[element_type]), - ("dAux", tuple_type) - ] - def __init__(self, kwargs) -> None: - ptr = kwargs[name] - self.ptr_aux = ptr - self.null_default = to_ctype_value(0, element_type) - self.dAux = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - if node.name in LoadImplBase.reserved_names: - return False - strideMN = node.tensor.stride[-2:] - if (strideMN[0] == 1 and strideMN[1] != 0 or - strideMN[0] != 0 and strideMN[1] == 1 ): - return True - else: - return False - - -class RowBroadcastImpl(LoadImplBase): - """ - Broadcast a row vector - """ - def __init__(self, node) -> None: - super().__init__(node) - self.stride_dtype = "int" - - @property - def argument_type(self): - stride_mnl = self.get_stride_mnl() - name = self.name - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - element_type = self.element - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr_row", ctypes.c_void_p), - ("null_default", dtype2ctype[element_type]), - ("dRow", tuple_type) - ] - def __init__(self, kwargs) -> None: - ptr = kwargs[name] - self.ptr_row = ptr - self.null_default = to_ctype_value(0, element_type) - self.dRow = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - if node.name in LoadImplBase.reserved_names: - return False - - strideMN = node.tensor.stride[-2:] - if strideMN == (0, 1): - return True - else: - return False - - -class ColumnBroadcastImpl(LoadImplBase): - """ - Broadcast a column vector - """ - def __init__(self, node) -> None: - super().__init__(node) - self.stride_dtype = "int" - - @property - def argument_type(self): - stride_mnl = self.get_stride_mnl() - name = self.name - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - element_type = self.element - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr_col", ctypes.c_void_p), - ("null_default", dtype2ctype[element_type]), - ("dCol", tuple_type) - ] - def __init__(self, kwargs) -> None: - ptr = kwargs[name] - self.ptr_col = int(ptr) - self.null_default = to_ctype_value(0, element_type) - self.dCol = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - if node.name in LoadImplBase.reserved_names: - return False - - strideMN = node.tensor.stride[-2:] - if strideMN == (1, 0): - return True - else: - return False - - -class ScalarBroadcastImpl(LoadImplBase): - """ - Broadcast a scalar - """ - def __init__(self, node) -> None: - super().__init__(node) - self.stride_dtype = "int" - - @property - def argument_type(self): - stride_mnl = self.get_stride_mnl() - name = self.name - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - element_type = self.element - - if self.tensor.is_constant: - value = self.tensor.value - class _Argument(ctypes.Structure): - _fields_ = [ - ("scalars", dtype2ctype[element_type]), - ("scalar_ptrs", ctypes.c_void_p), - ("dScalar", tuple_type) - ] - def __init__(self, kwargs) -> None: - self.scalars = to_ctype_value(value, element_type) - self.scalar_ptrs = 0 - self.dScalar = tuple_type(stride_mnl) - - else: - class _Argument(ctypes.Structure): - _fields_ = [ - ("scalars", dtype2ctype[element_type]), - ("scalar_ptrs", ctypes.c_void_p), - ("dScalar", tuple_type) - ] - def __init__(self, kwargs) -> None: - scalar_or_ptr = kwargs[name] - if isinstance(scalar_or_ptr, float): - self.scalars = to_ctype_value(scalar_or_ptr, element_type) - self.scalar_ptrs = 0 - else: - self.scalar_ptrs = int(scalar_or_ptr) - - self.dScalar = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - if node.name in LoadImplBase.reserved_names: - return False - - strideMN = node.tensor.stride[-2:] - if strideMN == (0, 0): - return True - else: - return False - - -class LoadNode(NodeBase): - """ - Load Node - """ - cnt = 0 - possible_impls = [ - AccumulatorImpl, LoadSrcImpl, AuxLoadImpl, - RowBroadcastImpl, ColumnBroadcastImpl, - ScalarBroadcastImpl - ] - def __init__(self, name: str) -> None: - if name is None: - name = f"load{LoadNode.cnt}" - LoadNode.cnt += 1 - super().__init__(name) - self.op = "load" - - def type_propagation(self, *args, **kwargs): - """ - Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`. - """ - if self.tensor is None: - raise RuntimeError(f"The tensor of node {self.name} is unknown.") - - self.element = self.tensor.element - self.element_output = self.tensor.element diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py deleted file mode 100644 index 606591b8e78c97114b85b329050d630d55460d7a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py +++ /dev/null @@ -1,306 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Base & visitor classes of DAGIR Nodes -""" - -import ctypes -from re import sub - -from cutlass_library import LayoutType - -from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple -from cutlass_cppgen.backend.evt.ir.tensor import Tensor - - -class TupleEmitter: - """ - Emit the cute tuple to C++ code - """ - def __init__(self, stride_dtype): - self.stride_dtype = stride_dtype - - def emit(self, py_tuple): - if isinstance(py_tuple, int): - if py_tuple in [0, 1]: - return f"cute::Int<{py_tuple}>" - else: - return f"{self.stride_dtype}" - elif isinstance(py_tuple, tuple): - decl = "cute::Stride<" - for item in py_tuple: - decl += self.emit(item) + ", " - return decl[:-2] + ">" - else: - raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}") - - -class ImplBase: - """ - Base class for Node Implementation - """ - def __init__(self, node) -> None: - self.node = node - self.name = node.name - self.tensor = node.tensor - self._type_decl = None - self.tuple_emitter = TupleEmitter("int64_t") - - @property - def stride_dtype(self): - return self.tuple_emitter.stride_dtype - - @stride_dtype.setter - def stride_dtype(self, stride_dtype): - self.tuple_emitter.stride_dtype = stride_dtype - - @staticmethod - def match(node, problem_size: tuple): - """ - Match function used in get_underlying_impl - """ - raise NotImplementedError(f"The `match` function is not defined.") - - @property - def argument_type(self): - """ - Default class for Argument Type - """ - class _Argument(ctypes.Structure): - _fields_ = [] - - def __init__(self, *args, **kwargs) -> None: - pass - - return _Argument - - @property - def name_camel(self) -> str: - """ - Return the CamelCase name. - """ - return sub(r"(_|-)+", " ", self.name).title().replace(" ", "") - - @property - def stride_mnl(self): - """ - Typename StrideMNL - """ - stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) - return self.tuple_emitter.emit(stride) - - def get_non_constant_stride(self, py_tuple): - if isinstance(py_tuple, int): - if py_tuple not in [0, 1]: - return py_tuple - else: - return None - non_constant_stride = [] - for item in py_tuple: - item_out = self.get_non_constant_stride(item) - if item_out: - non_constant_stride.append(item_out) - return tuple(non_constant_stride) - - def get_stride_mnl(self): - """ - Get the non-zero stride mnl. This is used in argument construction - """ - stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) - return stride - - def get_smem_size(self, *args, **kwargs): - """ - Get the shared memory size and alignment of current node - """ - return (0, 1) - - -class NoOpImpl(ImplBase): - """ - The NoOpImpl does nothing but forward its input to users - """ - def __init__(self, node) -> None: - super().__init__(node) - - @staticmethod - def match(node, problem_size: tuple): - if node.op == "store": - # Store that is not output is a No OP - return not node.is_output - - -class NodeBase: - """ - Base class of DAG Node - """ - def __init__(self, name: str) -> None: - self.name = name - self.underlying_impl = None - - self._tensor = None - - # Whether the node is disabled for emit - self.disabled = False - - @property - def name_camel(self) -> str: - """ - Return the CamelCase name. - """ - return self.underlying_impl.name_camel - - @property - def tensor(self) -> Tensor: - """ - Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) - """ - return self._tensor - - @tensor.setter - def tensor(self, kwargs): - """ - Setting the tensor - """ - self._tensor = Tensor(**kwargs) - - # - # Helper functions for type/shape propagation - # - - def shape_propagation(self, input_node_metas): - """ - Infer shape from input nodes - General Broadcasting Rules from NumPy - When operating on two arrays, we compare their shapes element-wise. - It starts with the trailing (i.e. rightmost) dimension and works its - way left. Two dimensions are compatible when - 1. they are equal - 2. one of them is 1 - """ - if self._tensor is not None: - return - - shape = None - for src in input_node_metas: - src_shape = src.tensor.shape - if shape is None: - shape = src_shape - else: - len_difference = len(shape) - len(src_shape) - if len_difference > 0: - for _ in range(len_difference): - src_shape = [1, ] + list(src_shape) - elif len_difference < 0: - for _ in range(-len_difference): - shape = [1, ] + list(shape) - broadcasted_shape = [] - # Infer broadcast shape - for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)): - if shape_dim == 1: - broadcasted_shape = [src_dim, ] + list(broadcasted_shape) - elif src_dim == 1: - broadcasted_shape = [shape_dim, ] + list(broadcasted_shape) - elif shape_dim == src_dim: - broadcasted_shape = [shape_dim, ] + list(broadcasted_shape) - else: - error_msg = "Dimension mismatch between " - for src_ in input_node_metas: - error_msg += f"{src_.name}{src_.tensor.shape}, " - error_msg = error_msg[:-2] + "." - raise RuntimeError(error_msg) - shape = tuple(broadcasted_shape) - - self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor) - - def type_propagation(self, *args, **kwargs): - """ - Each node is associated with two data types: `element` and `element_output`. - The `element_output` is the type of return array of the node. The `element` - has specific meaning for different node types. - * Load Node: data type of tensor in gmem - * Compute Node: element compute - * Store Node: data type of tensor in gmem - This function must be overloaded in the derived classes - """ - raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}") - - def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): - """ - Propagate the broadcast in the reversed topological order. - For example: - C[l, m, n] = A[m, 1] + B[l, m, n] - After the broadcast propagation, it will be come - C[l, m, n] = A[l, m, n] + B[l, m, n] - and each tensor will have a proper stride accessing the underlying tensor - """ - if self.tensor is None: - raise RuntimeError(f"The tensor of node {self.name} is unknown.") - for child in input_node_metas: - child.tensor.broadcast(self.tensor.shape) - - def get_underlying_impl(self, problem_size: tuple): - """ - Get the underlying implementation of the current node. - """ - if self.tensor is None: - raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.") - - for impl in self.possible_impls: - if impl.match(self, problem_size): - self.underlying_impl = impl(self) - break - - if self.underlying_impl is None: - raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.") - -# -# Visitor Nodes & Impls -# - -class TopoVisitorImpl(ImplBase): - """ - Impl for topological visitor - """ - def __init__(self, node) -> None: - super().__init__(node.output_node) - self.name = node.name - self.element_output = node.output_node.element_output - -class TopoVisitorNode(NodeBase): - def __init__(self, name: str, subgraph, output_node) -> None: - super().__init__(name) - self.subgraph = subgraph - self.output_node = output_node - self.op = "dag" - self.underlying_impl = TopoVisitorImpl(self) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py deleted file mode 100644 index 708405e0647ca3cb22bd0c1d4770d71810a469e2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py +++ /dev/null @@ -1,277 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Store node and implementations -""" - -import ctypes - -from cutlass_library import DataType - -from cutlass_cppgen.backend.c_types import tuple_factory -from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value -from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl -from cutlass_cppgen.backend.evt.ir.tensor import Tensor -from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp - - -class StoreImplBase(ImplBase): - """ - Base class for store node implementation - """ - reserved_names = ["D"] - def __init__(self, node) -> None: - super().__init__(node) - self.element = node.element - self.element_output = node.element_output - self.stride = node.store_tensor.stride - - -class StoreDImpl(StoreImplBase): - """ - Store D implementation - """ - - @property - def argument_type_d(self): - stride_mnl = self.get_stride_mnl() - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr_D", ctypes.c_void_p), - ("stride_D", tuple_type) - ] - def __init__(self, ptr: int) -> None: - self.ptr_D = ptr - self.stride_D = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - if node.name == "D" and node.store_tensor.shape == problem_size: - return True - return False - - -class AuxStoreImpl(StoreImplBase): - def __init__(self, node) -> None: - super().__init__(node) - self.round_style = FloatRoundStyle.ToNearest - - @property - def argument_type(self): - stride_mnl = self.get_stride_mnl() - name = self.name - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr_aux", ctypes.c_void_p), - ("dAux", tuple_type) - ] - def __init__(self, kwargs) -> None: - ptr = kwargs[name] - self.ptr_aux = ptr - self.dAux = tuple_type(stride_mnl) - - return _Argument - - @staticmethod - def match(node, problem_size: tuple): - if not node.is_output: - return False - if node.name in StoreImplBase.reserved_names: - return False - - strideMN = node.store_tensor.stride[-2:] - if (strideMN[0] == 1 and strideMN[1] != 0 or - strideMN[0] != 0 and strideMN[1] == 1 ): - return True - else: - return False - - -class ReductionImplBase(StoreImplBase): - def __init__(self, node) -> None: - super().__init__(node) - self.element = node.store_tensor.element - self.element_compute = node.element_compute - self.reg_reduce_fn = self.node.reg_reduce_fn - self.gmem_reduce_fn = self.node.gmem_reduce_fn - self.round_style = node.round_style - self.stride_dtype = "int" - - def get_reduce_identity(self): - """ - Return the reduction identity of the current reduce_fn - """ - maxes = { - DataType.f32: (2 ** 31) - 1, - DataType.f16: (2 ** 15), - DataType.s32: (2 ** 31) - 1, - DataType.s8: (2 ** 7) - 1 - } - mins = { - DataType.f32: -maxes[DataType.f32], - DataType.f16: -maxes[DataType.f16], - DataType.s32: -maxes[DataType.s32], - DataType.s8: -maxes[DataType.s8] - } - if self.reg_reduce_fn == FunctionalOp.Maximum: - if self.element_compute not in mins: - raise Exception(f"No min entry for data type {self.element_compute}") - return to_ctype_value(mins[self.element_compute], self.element_compute) - elif self.reg_reduce_fn == FunctionalOp.Multiplies: - return to_ctype_value(1., self.element_compute) - elif self.reg_reduce_fn == FunctionalOp.Minimum: - if self.element_compute not in maxes: - raise Exception(f"No max entry for data type {self.element_compute}") - return to_ctype_value(maxes[self.element_compute], self.element_compute) - else: - return to_ctype_value(0., self.element_compute) - - @property - def argument_type(self): - self.get_reduce_identity() - stride_mnl = self.get_stride_mnl() - name = self.name - tuple_type = tuple_factory(stride_mnl, self.stride_dtype) - element_compute = self.element_compute - reduce_identity = self.get_reduce_identity() - class _Argument(ctypes.Structure): - _fields_ = [ - ("ptr", ctypes.c_void_p), - ("reduce_identity", dtype2ctype[element_compute]), - ("dMNL", tuple_type) - ] - def __init__(self, kwargs) -> None: - ptr = kwargs[name] - self.ptr = ptr - self.reduce_identity = reduce_identity - self.dMNL = tuple_type(stride_mnl) - - return _Argument - - -class ColumnReductionImpl(ReductionImplBase): - - @staticmethod - def match(node, problem_size: tuple): - if not node.is_output: - return False - if node.name in StoreImplBase.reserved_names: - return False - - strideMN = node.store_tensor.stride[-2:] - if strideMN == (1, 0): - return True - else: - return False - - -class RowReductionImpl(ReductionImplBase): - - @staticmethod - def match(node, problem_size: tuple): - if not node.is_output: - return False - if node.name in StoreImplBase.reserved_names: - return False - - strideMN = node.store_tensor.stride[-2:] - if strideMN == (0, 1): - return True - else: - return False - - -class ScalarReductionImpl(ReductionImplBase): - - @staticmethod - def match(node, problem_size: tuple): - if not node.is_output: - return False - if node.name in StoreImplBase.reserved_names: - return False - - strideMN = node.store_tensor.stride[-2:] - if strideMN == (0, 0): - return True - else: - return False - - -class StoreNode(NodeBase): - """ - Store node - """ - possible_impls = [ - AuxStoreImpl, RowReductionImpl, - ColumnReductionImpl, ScalarReductionImpl, - NoOpImpl, StoreDImpl - ] - def __init__(self, name: str) -> None: - super().__init__(name) - self.op = "store" - self.is_output = False - self._store_tensor = None - - @property - def store_tensor(self) -> Tensor: - """ - Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) - """ - return self._store_tensor - - @store_tensor.setter - def store_tensor(self, kwargs): - """ - Setting the tensor - """ - self._store_tensor = Tensor(**kwargs) - - def type_propagation(self, input_node_metas: 'list[NodeBase]'): - """ - The store nodes has element_output = element_input - """ - if self.is_output: - if self.store_tensor is None: - raise RuntimeError(f"The store tensor of node {self.name} is unknown.") - self.element = self.store_tensor.element - assert len(input_node_metas) == 1, "Store node can only have one input node" - self.element_output = input_node_metas[0].element_output - - def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): - super().broadcast_propagation(input_node_metas) - if self.is_output: - self._store_tensor.broadcast(self.tensor.shape) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py deleted file mode 100644 index 1a28b7306a140d08bd1edebd3486990ea69b9344..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py +++ /dev/null @@ -1,137 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -High-level class for tensor -""" - -from cutlass_library import LayoutType - -from cutlass_cppgen.backend.evt.ir.layout_algorithm import ( - Layout, - broadcast, - canonicalization, - permutation, - reshape, - _reverse_tuple -) -from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type - - -class Tensor: - """ - The tensor abstracts the data type - """ - def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None: - if element is not None and tensor is not None: - raise Exception(f"Must not specify both element and tensor") - elif shape is not None and tensor is not None: - raise Exception(f"Must not specify both shape and tensor") - elif layout_tag is not None and tensor is not None: - raise Exception(f"Must not specify both layout_tag and tensor") - elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) : - raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)") - elif stride is not None and tensor is not None: - raise Exception(f"Must not specify both stride and tensor") - elif stride is not None and layout_tag is not None: - raise Exception(f"Must not specify layout_tag when stride is provided") - - if isinstance(tensor, Tensor): - # Directly copy all the attributes - self.__dict__.update(vars(tensor)) - else: - if tensor is None: - self.element = library_type(element) - else: - self.element, layout_tag = get_datatype_and_layout(tensor) - shape = get_tensor_shape(tensor) - if stride is not None: - self.layout = Layout(shape[::-1], stride[::-1]) - else: - if layout_tag == LayoutType.RowMajor: - self.layout = Layout(shape[::-1]) - elif layout_tag == LayoutType.ColumnMajor: - self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))]) - self.layout = canonicalization(self.layout) - - self.is_constant = is_constant - # Save the tensor value if it is constant - if is_constant and tensor is not None: - self.value = tensor - - @property - def shape(self): - """ - Returns the RowMajor layout shape - """ - return _reverse_tuple(self.layout.shape) - - @property - def stride(self): - """ - Returns the RowMajor layout stride - """ - return _reverse_tuple(self.layout.stride) - - @property - def rank(self): - """ - Returns the rank of the tensor - """ - return len(self.shape) - - # - # Layout Algorithms - # - - def broadcast(self, shape): - """ - Broadcast self.layout to shape - """ - assert isinstance(shape, tuple) - self.layout = broadcast(self.layout, _reverse_tuple(shape)) - - def reshape(self, shape): - """ - Reshape self.layout to shape - """ - assert isinstance(shape, tuple) - reverse_shape = _reverse_tuple(shape) - self.layout = reshape(self.layout, reverse_shape) - - def permute(self, indices): - """ - Permute self.layout according to indices - """ - length = len(indices) - indices = [length - idx - 1 for idx in indices] - self.layout = permutation(self.layout, indices[::-1]) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py deleted file mode 100644 index badc38d96a830992c94afa693ea4b56a8e404c96..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer -from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType -from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree -from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl -from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD -from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager -from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed -from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation -from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py deleted file mode 100644 index 8a28c6e4e62d1a7bd7431c81aac366b8788fd8df..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py +++ /dev/null @@ -1,143 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -from __future__ import annotations - -import subprocess - -from cutlass_library import DataTypeTag - -from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR - - -_COLOR_MAP = { - "load": '"AliceBlue"', - "compute": "LemonChiffon1", - "accumulator": "LightGrey", - "store": "PowderBlue", - "layout": "lightseagreen", - "dag": "darkorange" -} - - -class EVTGraphDrawer: - """ - Visualize a EVT DAGIR with graphviz - """ - def __init__( - self, - graph: DAGIR, - name: str - ): - self._name = name - self._dot_graphs = {} - - self._dot_graphs[name] = self._to_dot(graph, name) - - def _get_node_style(self, node): - template = { - "shape": "record", - "fillcolor": "#CAFFE3", - "style": '"filled,rounded"', - "fontcolor": "#000000", - } - if node.op in _COLOR_MAP: - template["fillcolor"] = _COLOR_MAP[node.op] - else: - raise NotImplementedError("unknown node op") - if node.disabled: - template["fontcolor"] = "grey" - template["fillcolor"] = "white" - return template - - def _get_node_label(self, node): - label = "{" + f"name={node.name}|op={node.op}" - if node.op == "layout": - label += f"|fn={node.fn.__name__}" - for key in node.kwargs: - label += f"|{key}={node.kwargs[key]}" - if node.underlying_impl is not None: - label += f"|impl={type(node.underlying_impl).__name__}" - if node.op == "load": - label += f"|element_output={DataTypeTag[node.underlying_impl.element]}" - elif node.op == "compute": - label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}" - elif node.op == "store": - label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}" - elif node.op == "dag": - label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}" - if node.tensor is not None: - shape = node.tensor.shape - stride = node.tensor.stride - label += f"|shape={shape}|stride={stride}" - - if hasattr(node, "store_tensor"): - if node.store_tensor is not None: - store_shape = node.store_tensor.shape - store_stride = node.store_tensor.stride - label += f"|store_shape={store_shape}|stride_stride={store_stride}" - - label += "}" - return label - - def _to_dot( - self, - graph: DAGIR, - name: str - ): - import pydot - dot_graph = pydot.Dot(name, randir="TB") - for node in graph.nodes_meta: - style = self._get_node_style(node) - label = self._get_node_label(node) - dot_node = pydot.Node( - node.name, label=label, **style - ) - dot_graph.add_node(dot_node) - if node.op == "dag": - dot_subgraph = self._to_dot(node.subgraph, name=node.name) - self._dot_graphs[node.name] = dot_subgraph - - # Add edges - for src, dst in graph.edges: - weight = graph.get_edge_weight(src, dst) - dot_graph.add_edge(pydot.Edge(src, dst, label=weight)) - - return dot_graph - - def get_dot_graph(self) -> pydot.Dot: - return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()] - - def get_dot_graph_by_name(self, name) -> pydot.Dot: - return self._dot_graphs[name] - - def get_main_dot_graph(self) -> pydot.Dot: - return self._dot_graphs[self._name] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py deleted file mode 100644 index b0c3cdbde6d46ad8a7e84c3b95422bdb55e877c5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py +++ /dev/null @@ -1,120 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Construct the epilogue visitor argument type -""" - -from cutlass_cppgen.backend.c_types import visitor_factory -from cutlass_cppgen.backend.evt.ir import TopoVisitorNode -from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree -from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase -from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation -from cutlass_cppgen.backend.evt.passes.util import cc_map - - -class PassGetArgumentType(EVTPassBase): - """ - Construct the epilogue visitor argument type - """ - dependencies = [ - PassShapeTypePropagation, # The Layout of all nodes must be set - PassDAG2Tree, # The type of each node must be set - PassGetImpl # The DAG subgraphs must be set - ] - - def requires(self) -> None: - # Check "D" is in the node list - if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")): - raise SyntaxError( - "Sm90+ EVT requires the epilogue to have a returned tensor D, " - "but the variable 'D' is not found in the return values.") - - def call(self): - nodes = self.dag_ir.nodes_topological_order() - self.argument_types = {} - for node in nodes: - meta = self.dag_ir.get_node_meta(node) - if not meta.disabled: - self.argument_types[node] = meta.underlying_impl.argument_type - if node == "D" and cc_map[self.cc] in [90, 100]: - continue - if isinstance(meta, TopoVisitorNode): - self.get_dag_argument_type(node) - else: - self.get_evt_argument_type(node) - - self.cc_specific_method(self.set_argument_type)() - - def get_evt_argument_type(self, node): - # Sort the input nodes by edge weight - input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)] - if len(input_types) > 0: - self.argument_types[node] = visitor_factory( - input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,]) - - def get_dag_argument_type(self, node): - meta = self.dag_ir.get_node_meta(node) - subgraph = meta.subgraph - subgraph_nodes = subgraph.nodes_topological_order() - # Visit the unvisited nodes in subgraph - for n in subgraph_nodes: - m = subgraph.get_node_meta(n) - if m.disabled: - continue - else: - self.argument_types[n] = m.underlying_impl.argument_type - input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]] - if len(input_types) > 0: - self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1]) - - def set_argument_type(self): - pass - - def sm90_set_argument_type(self): - self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]] - # Get the tensorD argument type - self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d - - # Get the tensorC argument type - if self.dag_ir.has_node("C"): - self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c - else: - self.dag_ir.arg_c_type = self.dag_ir.arg_d_type - - def sm100_set_argument_type(self): - self.sm90_set_argument_type() - - def sm80_set_argument_type(self): - nodes = self.dag_ir.nodes_topological_order() - self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py deleted file mode 100644 index 469769664abdf757319949ab48b4e7d5e982f200..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py +++ /dev/null @@ -1,169 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented -by the topological visitor, while the rest of the graph will be implemented with the tree visitor. -""" - -from copy import deepcopy - -from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode -from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase -from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation - - -class PassDAG2Tree(EVTPassBase): - """ - Convert the DAG IR to Tree by fusing subgraphs - """ - dependencies = [ - PassShapeTypePropagation, - PassGetImpl - ] - - def call(self): - # Step 1: find the nodes that have multiple parents - multi_parent_nodes = [] - - for node in self.dag_ir.nodes_topological_order(): - if self.dag_ir.out_degree(node) > 1: - multi_parent_nodes.append(node) - # Step 2: find the lowest common ancestor (LCA) of all its parents - for node in multi_parent_nodes: - # A multi-parent node could be already fused by the previous node - if not self.dag_ir.has_node(node): - continue - # A node uncovered by the previous fusions can have out degree change - # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change - # Case 2: it has more than one edges to the previously fused subgraph, degree drops - if self.dag_ir.out_degree(node) <= 1: - continue - - # Otherwise, the node still - reachable_nodes = [] - # Complexity: O(Dout*N) - for parent in self.dag_ir.get_users(node): - reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent))) - # get the common reachable objects - common_items = set.intersection(*reachable_nodes) - node_to_fuse = set.union(*reachable_nodes).difference(common_items) - - lca = None - # If common ancestor exists, find the lowest one - if len(common_items) > 0: - topo_order = self.dag_ir.nodes_topological_order() - topo_idx = -1 - for item in common_items: - if lca is None: - lca = item - topo_idx = topo_order.index(item) - else: - if topo_idx > topo_order.index(item): - lca = item - topo_idx = topo_order.index(item) - else: - # there is no common ancestor for all the parents, we pack all the reachable - # nodes into a single DAG node as a fallback. The lca should be the input node of - # one of the output nodes with out_degree = 0 - potential_output_nodes = [] - for node in node_to_fuse: - if self.dag_ir.out_degree(node) == 0: - potential_output_nodes.append(node) - if len(potential_output_nodes) == 0: - raise RuntimeError(f"No output node with out degree = 0 found.") - - output_node = None - if (self.dag_ir.cc >= 90): - # For SM90+, the lca should be the input node of D - if (not self.dag_ir.has_node("D")): - raise RuntimeError(f"D is not a node in the DAG IR.") - output_node = "D" - else: - output_node = potential_output_nodes[0] - - if (output_node is None): - raise RuntimeError(f"No output node found.") - lca = self.dag_ir.get_all_inputs(output_node)[0] - node_to_fuse.remove(output_node) - - # The lca is the output node of the DAG node - # Get the nodes to be fused - node_to_fuse.add(lca) - # Get all the input nodes - all_input_nodes = [] - all_output_nodes = [] - for node in node_to_fuse: - all_input_nodes.append(set(self.dag_ir.get_all_inputs(node))) - all_output_nodes.append(set(self.dag_ir.get_users(node))) - all_input_nodes = set.union(*all_input_nodes) - all_output_nodes = set.union(*all_output_nodes) - - new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes) - - # Create the subgraph - subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes) - subgraph = DAGIR(self.dag_ir.cc) - for node in subgraph_.nodes: - meta = deepcopy(self.dag_ir.get_node_meta(node)) - if node not in node_to_fuse: - meta.disabled = True - subgraph.add_node(meta) - for edge in subgraph_.edges: - subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])) - - - # Create the fused node - dag_node = TopoVisitorNode( - name=f"dag_{lca}", subgraph=subgraph, - output_node=self.dag_ir.get_node_meta(lca)) - self.dag_ir.add_node(dag_node) - - # Add input edges - for idx, node in enumerate(all_input_nodes): - self.dag_ir.add_edge(node, dag_node.name, weight=idx) - - # Replace all uses with DAG node (only 1 output node) - self.dag_ir.replace_all_uses_with(lca, dag_node.name) - - # Remove all fused nodes - node_to_fuse.remove(lca) - for node in node_to_fuse: - self.dag_ir.remove_node(node) - - def ensures(self) -> None: - # Ensure that after the pass, the resulting DAG becomes a tree - for node in self.dag_ir.nodes: - out_degree = self.dag_ir.out_degree(node) - if out_degree > 1: - raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py deleted file mode 100644 index 0d57c5b799d125ccc9491760259569731c0bf3ca..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py +++ /dev/null @@ -1,64 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Fix the element_output of producer of D. - -In Sm90 epilogue visitor, the node writing D to gmem does not have internal -element converter, so the compute node producing D must have element_output = type(D). -""" - -from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase - - -class PassFixElementD(EVTPassBase): - """ - In Sm90 epilogue visitor, the node writing D to gmem does not have internal - element converter, so the compute node producing D must have - element_output = type(D) - """ - dependencies = [ - PassLayoutManipulateElimination - ] - def get_producer(self, node, element_D): - node_meta = self.dag_ir.get_node_meta(node) - if node_meta.op == "compute": - node_meta.element_output = element_D - elif node_meta.op == "store": - self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D) - - def call(self): - if self.dag_ir.has_node("D"): - node_d_meta = self.dag_ir.get_node_meta("D") - element_D = node_d_meta.store_tensor.element - self.get_producer("D", element_D) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py deleted file mode 100644 index 90fdafe7d0e80492bd2e641c69f11d95aace6bba..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py +++ /dev/null @@ -1,90 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Infer the underlying implement of each node. - -While the frontend only distinguish between Load/Store/Compute Node, -each of these nodes can have different underlying implementation based -on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc. -This pass infers the underlying impl of each node -""" - -import cutlass_cppgen.backend.evt.backend as evt_backend -from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode -from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase -from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination -from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation -from cutlass_cppgen.backend.evt.passes.util import cc_map - - -class PassGetImpl(EVTPassBase): - """ - While the frontend only distinguish between Load/Store/Compute Node, - each of these nodes can have different underlying implementation based - on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc. - This pass infers the underlying impl of each node - """ - dependencies = [ - PassShapeTypePropagation, # The shape and type info are required for inference - PassFixElementD - ] - - def __init__(self, dag_ir: DAGIR) -> None: - super().__init__(dag_ir) - self.no_op_elimination = PassNoOpElimination(dag_ir) - - def requires(self) -> None: - # Verify "accum" is in the arg list - if not self.dag_ir.has_node("accum"): - raise SyntaxError("Cannot find 'accum' in the argument list.") - - def call(self): - # The loop structure of the epilogue is determined by the - # accumulator shape - accumulator: LoadNode = self.dag_ir.get_node_meta("accum") - problem_size = accumulator.tensor.shape - - for node_meta in self.dag_ir.node_metas_topological_order(): - node_meta.get_underlying_impl(problem_size) - - def ensures(self) -> None: - # Some nodes will be lowered to NoOp, eliminate them - self.no_op_elimination() - # Lower to cc-specific impl - for node_meta in self.dag_ir.nodes_meta: - node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes") - node_meta.underlying_impl = getattr( - node_impl_ccs, - f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__ - )(node_meta) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py deleted file mode 100644 index af147969f016b50ef05034fca99b173777948622..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py +++ /dev/null @@ -1,217 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Eliminate layout manipulation nodes -""" - -from copy import deepcopy - -from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase -from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation - - -class PassLayoutManipulateElimination(EVTPassBase): - """ - Eliminate layout manipulation nodes - """ - dependencies = [PassShapeTypePropagation] - - def __init__(self, dag_ir: DAGIR) -> None: - super().__init__(dag_ir) - self.copy_cnt = 0 - - def call(self): - self.layout_nodes_worklist = self.get_all_layout_nodes() - # Run while loop utill all layout nodes are eliminated - while(len(self.layout_nodes_worklist) > 0): - node = self.layout_nodes_worklist.pop(0) - # for node in layout_nodes: - # Step 1: get the propagation direction - direction = self.get_propagation_direction(node) - self.visited = [] - getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node) - # Eliminate the current node - input_node = self.dag_ir.get_all_inputs(node)[0] - self.dag_ir.replace_all_uses_with(node, input_node) - # layout_nodes = self.get_all_layout_nodes() - - def get_all_layout_nodes(self): - layout_nodes = [] - for node_meta in reversed(self.dag_ir.node_metas_topological_order()): - if isinstance(node_meta, LayoutNode): - layout_nodes.append(node_meta.name) - return layout_nodes - - def get_propagation_direction(self, node: str): - """ - The logic is propagating all layout nodes away from the accumulator node. - """ - self.visited = [] - self.get_influenced_users(node) - nodes_influenced_dir_users = self.visited - self.visited = [] - self.get_influenced_inputs(node) - nodes_influenced_dir_inputs = self.visited - - if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs: - return "inputs" - elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs: - return "users" - else: - raise RuntimeError("Unsolved propagation direction") - - # Get all influenced nodes if we propagate along the user direction - def get_influenced_users(self, node: str): - if node in self.visited: - return - self.visited.append(node) - - users = self.dag_ir.get_users(node) - for user in users: - self.get_influenced_users(user) - user_inputs = [] - for user in users: - user_inputs.append(set(self.dag_ir.get_all_inputs(user))) - if len(user_inputs) > 0: - user_inputs = set.union(*user_inputs) - user_inputs.remove(node) - for input in user_inputs: - self.get_influenced_inputs(input) - - # Get all influenced nodes if we propagate along the input direction - def get_influenced_inputs(self, node: str): - if node in self.visited: - return - self.visited.append(node) - - inputs = self.dag_ir.get_all_inputs(node) - for input in inputs: - self.get_influenced_inputs(input) - input_users = [] - for input in inputs: - input_users.append(set(self.dag_ir.get_users(input))) - if len(input_users) > 0: - input_users = set.union(*input_users) - input_users.remove(node) - for user in input_users: - self.get_influenced_users(user) - - def add_copy_before(self, layout_node_meta: LayoutNode, target: str): - copied_node_meta = deepcopy(layout_node_meta) - copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}" - self.copy_cnt += 1 - copied_node_meta.name = copied_node - self.dag_ir.add_node(copied_node_meta) - # Add edges - target_inputs = self.dag_ir.get_all_inputs(target) - for src in target_inputs: - self.dag_ir.remove_edge(src, target) - self.dag_ir.add_edge(src, copied_node) - self.dag_ir.add_edge(copied_node, target) - self.layout_nodes_worklist.append(copied_node) - - def add_copy_after(self, layout_node_meta: LayoutNode, target: str): - copied_node_meta = deepcopy(layout_node_meta) - copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}" - self.copy_cnt += 1 - copied_node_meta.name = copied_node - self.dag_ir.add_node(copied_node_meta) - # Add edges - users = self.dag_ir.get_users(target) - for user in users: - self.dag_ir.remove_edge(target, user) - self.dag_ir.add_edge(copied_node, user) - self.dag_ir.add_edge(target, copied_node) - self.layout_nodes_worklist.append(copied_node) - - # Propagate the layout `node` along the user direction - def propagate_to_users(self, layout_node_meta: LayoutNode, node: str): - """ - Propagate layout node to users - """ - if node in self.visited: - # Avoid applying twice - return - self.visited.append(node) - - node_meta = self.dag_ir.get_node_meta(node) - if layout_node_meta.name != node: - if isinstance(node_meta, LayoutNode): - # Layout node is not transparent with layout node - self.add_copy_before(layout_node_meta, node) - return - else: - layout_node_meta.apply_to_user(node_meta) - - users = self.dag_ir.get_users(node) - user_inputs = [] - for user in users: - user_inputs.append(set(self.dag_ir.get_all_inputs(user))) - for user in users: - self.propagate_to_users(layout_node_meta, user) - if len(user_inputs) > 0: - user_inputs = set.union(*user_inputs) - user_inputs.remove(node) - for input in user_inputs: - self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input) - - # Propagate the layout `node` along the input direction - def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str): - """ - Propagate layout node to inputs - """ - if node in self.visited: - # Avoid applying twice - return - self.visited.append(node) - - node_meta = self.dag_ir.get_node_meta(node) - if layout_node_meta.name != node: - if isinstance(node_meta, LayoutNode): - # Layout node is not transparent with layout node - self.add_copy_after(layout_node_meta, node) - return - else: - layout_node_meta.apply_to_input(node_meta) - inputs = self.dag_ir.get_all_inputs(node) - input_users = [] - for input in inputs: - input_users.append(set(self.dag_ir.get_users(input))) - for input in inputs: - self.propagate_to_inputs(layout_node_meta, input) - if len(input_users) > 0: - input_users = set.union(*input_users) - input_users.remove(node) - for user in input_users: - self.propagate_to_users(layout_node_meta.get_inverse_node(), user) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py deleted file mode 100644 index e8b46bddb06e7c20be6d20526792777edef64b90..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py +++ /dev/null @@ -1,164 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Pass manager for DAG IR. -""" - -from typing import Any - -import networkx as nx - -from cutlass_cppgen.backend.evt.ir import DAGIR -from cutlass_cppgen.backend.evt.passes.util import cc_map - - -class EVTPassBase: - """ - Base class for EVT Passes - """ - dependencies = [] - def __init__(self, dag_ir: DAGIR) -> None: - self.dag_ir = dag_ir - self.cc = self.dag_ir.cc - - def requires(self) -> None: - """ - This function will be called before the pass is run. - """ - pass - - def call(self) -> None: - """ - The pass that is run through the self.dag_ir - """ - raise NotImplementedError( - f"__call__ is not overwritten in Pass {self.__class__.__name__}") - - def ensures(self) -> None: - """ - This function will be called after the pass is run. - """ - pass - - def __call__(self) -> Any: - self.requires() - self.call() - self.ensures() - - def cc_specific_method(self, func): - """ - This enables defining function that behaves differently under different cc - The simplest example of using this function is the following - - .. highlight:: python - .. code-block:: python - - class ExamplePass(EVTPassBase): - - def call(sekf): - # This automatically select the smXX_func based on current cc - self.cc_specific_method(self.func)() - - # Interface func, can be empty - def func(self): - pass - - # Sm90 specific func - def sm90_func(self): - // sm90 specific method - return - - # Sm80 specific func - def sm80_func(self): - // sm80 specific method - return - """ - func_name = f"sm{cc_map[self.cc]}_{func.__name__}" - if hasattr(self, func_name): - return getattr(self, func_name) - else: - raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}") - - -class EVTPassManager(nx.DiGraph): - """ - Topological-based Pass Manager. - Each registered pass has a list of dependencies. The pass manager organizes - the passes as a DAG and launch the compiler passes under topological order. - """ - def __init__(self, dag_ir: DAGIR, pass_list): - super().__init__() - self.dag_ir = dag_ir - for pass_cls in pass_list: - self.add_pass(pass_cls) - - self.sorted_passes = self.schedule() - - def get_callable(self, pass_name): - """ - Return the callable of the pass - """ - return self.nodes[pass_name]["callable"] - - def add_pass(self, pass_cls): - """ - Add a pass to the pass manager - :param pass_cls: the class of pass - :type pass_cls: derived class of EVTPassBase - """ - name = pass_cls.__name__ - pass_callable = pass_cls(self.dag_ir) - self.add_node(name, callable=pass_callable) - - def schedule(self): - """ - Schedule the added passes under topological order - """ - # Add edges - for pass_name in self.nodes: - callable = self.get_callable(pass_name) - for dependency_cls in callable.dependencies: - self.add_edge( - dependency_cls.__name__, - type(callable).__name__) - - # Topological sort - return list(nx.topological_sort(self)) - - def __call__(self) -> Any: - """ - Launch the registered passes - """ - for pass_name in self.sorted_passes: - callable = self.get_callable(pass_name) - callable() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py deleted file mode 100644 index 13107eb1d11c9a436348a4e50a92e62ce6f8b312..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py +++ /dev/null @@ -1,53 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -No op elimination node -""" - -from typing import Any - -from cutlass_cppgen.backend.evt.ir import NoOpImpl -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase - - -class PassNoOpElimination(EVTPassBase): - """ - The dead node elimination pass removes nodes with NoOpImpl in DAG IR - """ - dependencies = [] - - def call(self) -> Any: - for node in self.dag_ir.nodes_topological_order(): - node_meta = self.dag_ir.get_node_meta(node) - if isinstance(node_meta.underlying_impl, NoOpImpl): - self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0]) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py deleted file mode 100644 index 6423a2b845dd643650cf99037178030bee6f0dbd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py +++ /dev/null @@ -1,97 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Preprocess the reduction nodes. - -The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store() -This pass fuses these into a single store node, and then replaces all uses of the -current node with the new store node. -""" - -from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase - - -class PassPreprocessRed(EVTPassBase): - """ - Preprocess red nodes - """ - - def call(self): - # Step 1: find the compute nodes with op=red - red_compute_nodes = [] - for node_meta in self.dag_ir.nodes_meta: - if isinstance(node_meta, ComputeNode): - if type(node_meta.fn) == tuple: - # To keep the frontend simple, the reduction nodes - # are parsed into compute nodes by default - # The simple heuristic to distinguish between compute - # and reduction node is that compute node is a single function, - # while the reduction node is a tuple of functions for - # in-register reduction and atomic global memory reduction - red_compute_nodes.append(node_meta.name) - - # Step 2: for each compute, merge it with the succeeding store - for node in red_compute_nodes: - # Verify - users = self.dag_ir.get_users(node) - inputs = self.dag_ir.get_all_inputs(node) - # Has a single user - assert len(users) == 1 - assert len(inputs) == 1 - user = users[0] - input = inputs[0] - - user_meta = self.dag_ir.get_node_meta(user) - # Must be a store node - assert isinstance(user_meta, StoreNode) - # With output degree == 0 - assert self.dag_ir.out_degree(user) == 0 - # Register the reduce op - node_meta = self.dag_ir.get_node_meta(node) - user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn - user_meta.element_compute = node_meta.element_compute - user_meta.round_style = node_meta.round_style - - # Replace all uses - self.dag_ir.remove_edge(input, node) - input_users = self.dag_ir.get_users(input) - for iu in input_users: - weight = self.dag_ir.get_edge_weight(input, iu) - self.dag_ir.add_edge(user, iu, weight) - self.dag_ir.remove_edge(input, iu) - self.dag_ir.add_edge(input, user) - self.dag_ir.remove_node(node) - - # Register the reduction name - self.dag_ir.reduction_names.append(user) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py deleted file mode 100644 index cb90a82c8f637429d3c64b3d881eb30d02c8c804..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py +++ /dev/null @@ -1,59 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Shape and type propagation pass -""" - -from cutlass_cppgen.backend.evt.ir.node import NodeBase -from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase -from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed - - -class PassShapeTypePropagation(EVTPassBase): - """ - Propagate the shape and type of all nodes - """ - dependencies = [PassPreprocessRed] - - def call(self): - # Propagate the node shape and type - for node in self.dag_ir.nodes_topological_order(): - node_meta: NodeBase = self.dag_ir.get_node_meta(node) - input_node_metas = self.dag_ir.get_all_inputs_meta(node) - node_meta.type_propagation(input_node_metas) - node_meta.shape_propagation(input_node_metas) - - for node in reversed(self.dag_ir.nodes_topological_order()): - node_meta: NodeBase = self.dag_ir.get_node_meta(node) - input_node_metas = self.dag_ir.get_all_inputs_meta(node) - node_meta.broadcast_propagation(input_node_metas) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py deleted file mode 100644 index 8168c59733a5da15eacbbe583c890610655ecff5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py +++ /dev/null @@ -1,319 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Compute the shared memory size in bytes -""" - -from math import gcd - -import cutlass_library -from pycute import flatten, shape_div, product - -import cutlass_cppgen -from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR -from cutlass_cppgen.backend.library import DataType, DataTypeSize - - -class GetSmemSize: - """ - Get the size in byte of shared memory used by the kernel - """ - def __init__(self, dag_ir: DAGIR) -> None: - self.dag_ir = dag_ir - self.cc = self.dag_ir.cc - - # - # Sm90 epilogue specific - # - - def sm90_epilogue_tile(self, tile_description): - # Get the epilogue tile size - schedule = tile_description.epilogue_schedule - if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized: - element_d = self.dag_ir.get_node_meta("D").element - nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32 - epi_tile_m = min(64, tile_description.threadblock_shape[0]) - epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) - epilogue_tile_mn = (epi_tile_m, epi_tile_n) - elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative: - epi_tile_m = min(128, tile_description.threadblock_shape[0]) - epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) - epilogue_tile_mn = (epi_tile_m, epi_tile_n) - else: - raise NotImplementedError(f"Unsupported schedule: {schedule}") - - # Get the pipeline stages - stages_d = 2 - epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) - if self.dag_ir.has_node("C"): - element_c = self.dag_ir.get_node_meta("C").element - else: - element_c = None - - element_d = self.dag_ir.get_node_meta("D").element - if element_c == element_d: - reuse_smem_c = True - else: - reuse_smem_c = False - stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles - - # Record the epilogue tile - self.cta_tile_mnk = tuple(tile_description.threadblock_shape) - self.epilogue_tile_mn = epilogue_tile_mn - self.epi_tiles = epi_tiles - self.stages_c = stages_c - self.stages_d = stages_d - self.reuse_smem_c = reuse_smem_c - self.element_c = element_c - self.element_d = element_d - self.is_source_supported = element_c is not None - - def sm90_or_sm100_epilogue_smem_size(self, tile_description): - # Get the Fusion Storage - nodes = self.dag_ir.nodes_topological_order() - self.smem_types = {} - for node in nodes: - meta = self.dag_ir.get_node_meta(node) - if not meta.disabled: - self.smem_types[node] = meta.underlying_impl.get_smem_size( - self.cta_tile_mnk, self.epilogue_tile_mn, - self.stages_c, self.stages_d, self.epi_tiles) - if node == "D": - continue - if isinstance(meta, TopoVisitorNode): - self.get_dag_smem_type(node) - else: - self.get_evt_smem_type(node) - - thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0] - # Get the Tensor Storage - tensors = [] - if self.is_source_supported: - smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8 - tensors.append((smem_C, 128)) - else: - tensors.append((0, 1)) - if self.reuse_smem_c: - tensors.append((0, 128)) - else: - smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8 - tensors.append((smem_D, 128)) - tensors.append((thread_smem_size, 128)) - - tensor_smem_size = self.get_struct_size(tensors) - # Get pipeline storage size - # sizeof(uint64_t * stages_c * 2), alignment of uint64_t - # 2 is for FullBarrier and EmptyBarrier - pipeline_smem_size = (8 * self.stages_c * 2, 8) - - # get SharedStorage size - smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size]) - return smem_size[0] - - def sm90_epilogue_smem_size(self, tile_description): - """ - Compute the shared memory size of sm90 collective epilogue - """ - self.sm90_epilogue_tile(tile_description) - return self.sm90_or_sm100_epilogue_smem_size(tile_description) - - # - # Sm100 epilogue specific - # - - def sm100_epilogue_tile(self, tile_description): - cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1]) - mma_tile = cta_tile - - if tile_description.is_2sm: - cta_tile = (cta_tile[0] // 2, cta_tile[1]) - - if tile_description.is_2sm and mma_tile[0] == 128: - tmem_warps = (2, 2) - else: - tmem_warps = (4, 1) - - if self.dag_ir.has_node("C"): - element_c = self.dag_ir.get_node_meta("C").element - element_c_size = DataTypeSize[element_c] - else: - element_c = None - element_c_size = 0 - - element_d = self.dag_ir.get_node_meta("D").element - - DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void - - CtaM = cta_tile[0] - CtaN = cta_tile[1] - WarpM = tmem_warps[0] - WarpN = tmem_warps[1] - MaxBits = max(element_c_size, DataTypeSize[element_d]) - DpFull = 32 - M = min(CtaM, DpFull * WarpM) - - if DisableSource: - # Epilogues w/o residual load are less sensitive to smem allocation - # Target a fixed amount of compute per epilogue iteration - if MaxBits == 4: - # Make epilogue tile larger to reduce the epilogue iterations. - # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. - ComputeElts = 8192 - Nperf = ComputeElts // M - else: - ComputeElts = 4096 - Nperf = ComputeElts // M - else: - # Epilogues w/ residual load are more sensitive to smem allocation - # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize - if MaxBits == 32: - Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32 - elif MaxBits == 16: - Nperf = 32 if CtaN <= 128 else 64 - else: - Nperf = 64 - - def is_m_major(layout): - return flatten(layout.stride[0]) == 1 - - if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout): - N_min_C = 8 * WarpN - elif element_c_size == 6: - N_min_C = 128 * WarpN - else: - N_min_C = (128 // element_c_size) * WarpN - - if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout): - N_min_D = 8 * WarpN - elif DataTypeSize[element_d] == 6: - N_min_D = 128 * WarpN - else: - N_min_D = (128 // DataTypeSize[element_d]) * WarpN - - N = min(CtaN, max(Nperf, N_min_C, N_min_D)) - - tile_m = M - tile_n_size = N // WarpN * WarpN - - epilogue_tile_mn = (tile_m, tile_n_size) - epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) - - stages_d = min(epi_tiles, 2) - reuse_smem_c = (element_c_size > 8) - - if reuse_smem_c: - stages_c = max(min(epi_tiles, 4), stages_d + 1) - else: - stages_c = min(epi_tiles, 4) - - # Record the epilogue tile - self.cta_tile_mnk = tuple(tile_description.threadblock_shape) - self.epilogue_tile_mn = epilogue_tile_mn - self.epi_tiles = epi_tiles - self.stages_c = stages_c - self.stages_d = stages_d - self.reuse_smem_c = reuse_smem_c - self.element_c = element_c - self.element_d = element_d - self.is_source_supported = not DisableSource - - def sm100_epilogue_smem_size(self, tile_description): - """ - Compute the shared memory size of sm100 collective epilogue - """ - self.sm100_epilogue_tile(tile_description) - return self.sm90_or_sm100_epilogue_smem_size(tile_description) - - def __call__(self, tile_description): - return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description) - - # - # Helper functions - # - - @staticmethod - def get_visitor_size(members: list, ebo: bool): - """ - Get the size of struct in bytes - """ - offset = 0 - max_alignment = 1 - if len(members) > 0: - # Get alignment - for _, alignment in members: - max_alignment = max(max_alignment, alignment) - - for type_size, _ in members: - if type_size != 0: - offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment - if type_size == 0 and not ebo: - offset += 1 - else: - offset += type_size - offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment - return (offset, max_alignment) - else: - # Struct size is at least 1 - return (1, 1) - - def get_struct_size(self, members: list): - """ - Get the size of struct in bytes - """ - return self.get_visitor_size(members, False) - - def get_evt_smem_type(self, node): - # Sort the input nodes by edge weight - input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)] - input_types.append(self.smem_types[node]) - if len(input_types) > 1: - ebo = len(input_types) > 4 - self.smem_types[node] = self.get_visitor_size(input_types, ebo) - - def get_dag_smem_type(self, node): - meta = self.dag_ir.get_node_meta(node) - subgraph = meta.subgraph - subgraph_nodes = subgraph.nodes_topological_order() - # Visit the unvisited nodes in subgraph - for n in subgraph_nodes: - m = subgraph.get_node_meta(n) - if m.disabled: - continue - else: - self.smem_types[n] = m.underlying_impl.get_smem_size( - self.cta_tile_mnk, self.epilogue_tile_mn, - self.stages_c, self.stages_d, self.epi_tiles) - input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]] - if len(input_types) > 0: - ebo = len(input_types) > 4 - self.smem_types[node] = self.get_visitor_size(input_types, ebo) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py deleted file mode 100644 index 4b72e330523ca1e4fb8c5d4526289641e158e72e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py +++ /dev/null @@ -1,46 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for passes -""" - -# Map from the CC of the kernel to the EVT implementation that the CC targets -cc_map = { - 80: 80, - 86: 80, - 89: 80, - 90: 90, - 100: 100, - 101: 100, - 103: 100, -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py deleted file mode 100644 index a959976b8601b0793c4c7c1709d61c8c838df838..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py +++ /dev/null @@ -1,109 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -from __future__ import annotations - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -import numpy as np - -from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice -from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor - - -class NumpyFrontend: - """ - Frontend node for numpy - """ - - @staticmethod - def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr: - """Convert the input numpy tensor to CUDA device pointer - - :param np_tensor: input numpy nd array - :param is_output: whether the tensor is output - - :return: CUDA device pointer - """ - # copy the data to device - if is_output: - return device_mem_alloc(np_tensor.size * np_tensor.itemsize) - else: - return todevice(np_tensor) - - -class TorchFrontend: - """ - Frontend node for torch - """ - - @staticmethod - def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr: - """Convert the input torch tensor to CUDA device pointer - - :param torch_tensor: input torch tensor - :param is_output: whether the tensor is output - - :return: CUDA device pointer - """ - - # check the device of torch_tensor - if not torch_tensor.is_cuda: - torch_tensor = torch_tensor.to("cuda") - - return cuda.CUdeviceptr(torch_tensor.data_ptr()) - - -class CupyFrontend: - """ - Frontend node for cupy - """ - - @staticmethod - def argument(cupy_ndarray: "cp.ndarray"): - return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr)) - - -class TensorFrontend: - """ - Universal Frontend for client-provide tensors - """ - - @staticmethod - def argument(tensor, is_output=False): - if is_numpy_tensor(tensor): - return NumpyFrontend.argument(tensor, is_output) - elif is_torch_tensor(tensor): - return TorchFrontend.argument(tensor) - elif is_cupy_tensor(tensor): - return CupyFrontend.argument(tensor) - else: - raise NotImplementedError("Unknown Tensor Type") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py deleted file mode 100644 index 5e2a3a30a097eb45c691554daf70f8db12e5bc48..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py +++ /dev/null @@ -1,2145 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -from __future__ import annotations - -import copy -import ctypes -import enum - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") -from cutlass_library import SubstituteTemplate -import numpy as np - -from cutlass_library import ( - ComplexTransformTag, - DataType, - DataTypeNames, - DataTypeSize, - DataTypeTag, - EpilogueScheduleSuffixes, - EpilogueScheduleTag, - EpilogueScheduleType, - GemmKind, - GemmKindNames, - GemmUniversalMode, - KernelScheduleSuffixes, - KernelScheduleTag, - KernelScheduleType, - LayoutTag, - LayoutType, - MathOperation, - MathOperationTag, - OpcodeClass, - OpcodeClassNames, - OpcodeClassTag, - OperationKind, - ShortComplexLayoutNames, - ShortDataTypeNames, - ShortLayoutTypeNames, - SwizzlingFunctor, - SwizzlingFunctorTag, - TileSchedulerSuffixes, - TileSchedulerTag, - TileSchedulerType, - get_complex_from_real -) -from cutlass_cppgen.backend.arguments import ArgumentBase -from cutlass_cppgen.backend.c_types import ( - GemmCoord_, - GemmCoordBatched_, - GenericMainloopArguments3x_, - StrideBatched_, - dim3_, - get_gemm_arguments, - get_gemm_arguments_3x, - get_gemm_arguments_streamk, - get_gemm_grouped_arguments, - get_mainloop_arguments_3x, - get_tile_scheduler_arguments_3x, -) -from cutlass_cppgen.backend.library import ( - ApiVersion, - EmissionType, - SchedulerMode, - SchedulerModeTag, - TensorDescription, - TileDescription, - api_version, -) -from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice -from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor -from cutlass_cppgen.backend.utils.device import device_sm_count -from cutlass_cppgen.shape import GemmCoord, MatrixCoord - - -################################################################################ -# -# Data structure modeling a GEMM operation -# -################################################################################ - - -def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int: - """ - Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``. - - :param layout: layout of the tensor - :type layout: cutlass_cppgen.shape.LayoutType - :param shape: shape of the tensor - :type shape: cutlass_cppgen.shape.MatrixCoord - - :return: leading dimension of the tensor - :rtype: int - """ - if layout == LayoutType.RowMajor: - return shape.column - elif layout == LayoutType.ColumnMajor: - return shape.row - - -def transpose_layout(layout: LayoutType) -> LayoutType: - if layout == LayoutType.ColumnMajor: - return LayoutType.RowMajor - elif layout == LayoutType.RowMajor: - return LayoutType.ColumnMajor - else: - raise ValueError(f"Unsupported Layout {layout}") - - -class GemmArguments2x(ArgumentBase): - """ - Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and - user-provide tensors into the kernel's argument - - :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | - :class:`cutlass_cppgen.backend.GemmOperationGrouped` - - :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass_cppgen.shape.GemmCoord` - - :param A: tensor A - :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param B: tensor B - :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param C: tensor C - :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param D: tensor D - :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` - - :param output_op: output operator, optional - :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` - - :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) - :type stream: :class:`cuda.cuda.CUstream` - """ - - def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): - self.operation = operation - - self.layout_A = operation.A.layout - self.layout_B = operation.B.layout - self.layout_C = operation.C.layout - - self.element_A = operation.A.element - self.element_B = operation.B.element - self.element_C = operation.C.element - - if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]: - raise Exception("Interleaved layout not currently supported") - - if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]: - super().__init__(A, B, None, None, **kwargs) - else: - super().__init__(A, B, C, D, **kwargs) - - if operation.switched: - self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) - self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A - else: - self.problem_size = problem_size - # If the number of elements in C = problem_size.n, C is treated as the bias - if hasattr(self, "tensor_c_numel"): - if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1: - self.bias = True - - self.lda = leading_dimension(self.layout_A, self.problem_size.mk) - self.ldb = leading_dimension(self.layout_B, self.problem_size.kn) - self.ldc = leading_dimension(self.layout_C, self.problem_size.mn) - self.ldd = self.ldc - - if self.bias: - self.ldc = 0 - - if "output_op" in kwargs.keys() and gemm_mode != GemmUniversalMode.GemmSplitKParallel: - self.output_op = kwargs["output_op"] - else: - if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: - dtype = int - else: - dtype = float - self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0)) - - self.gemm_mode = gemm_mode - if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]: - if "split_k_slices" in kwargs.keys(): - self.batch_count = kwargs["split_k_slices"] - else: - self.batch_count = 1 - self.split_k_slices = self.batch_count - - if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]: - if "batch" in kwargs.keys(): - self.batch_count = kwargs["batch"] - else: - self.batch_count = 1 - - if "batch_strides" in kwargs: - self.batched_stride_A = kwargs["batch_strides"]["A"] - self.batched_stride_B = kwargs["batch_strides"]["B"] - self.batched_stride_C = kwargs["batch_strides"]["C"] - self.batched_stride_D = kwargs["batch_strides"]["D"] - else: - self.batched_stride_A = self.problem_size.m * self.problem_size.k - self.batched_stride_B = self.problem_size.n * self.problem_size.k - self.batched_stride_C = self.problem_size.m * self.problem_size.n - self.batched_stride_D = self.problem_size.m * self.problem_size.n - - if self.bias: - self.batched_stride_C = self.problem_size.n - - if gemm_mode == GemmUniversalMode.Array: - self.ptr_A_array = [] - self.ptr_B_array = [] - self.ptr_C_array = [] - self.ptr_D_array = [] - - ptr_A_addr = int(self.ptr_A) - ptr_B_addr = int(self.ptr_B) - ptr_C_addr = int(self.ptr_C) - ptr_D_addr = int(self.ptr_D) - - stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8 - stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8 - stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8 - stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8 - for _ in range(self.batch_count): - self.ptr_A_array.append(ptr_A_addr) - self.ptr_B_array.append(ptr_B_addr) - self.ptr_C_array.append(ptr_C_addr) - self.ptr_D_array.append(ptr_D_addr) - - ptr_A_addr += stride_A - ptr_B_addr += stride_B - ptr_C_addr += stride_C - ptr_D_addr += stride_D - - self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64) - self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64) - self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64) - self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64) - - if isinstance(self.operation, GemmOperationUniversal): - self.initialize() - - def get_arguments(self): - problem_size_ = self.problem_size.ctype - grid_tiled_shape_ = GemmCoord( - self.grid_tiled_shape.x, - self.grid_tiled_shape.y, - self.grid_tiled_shape.z ).ctype - - if self.gemm_mode == GemmUniversalMode.Array: - arguments = self.operation.argument_type( - # Arguments from UniversalArgumentsBase - self.gemm_mode, - problem_size_, - self.batch_count, - 0, - # Remaining arguments - self.output_op, - int(self.ptr_A_array_buffer.ptr), - int(self.ptr_B_array_buffer.ptr), - int(self.ptr_C_array_buffer.ptr), - int(self.ptr_D_array_buffer.ptr), - 0, 0, 0, - self.lda, self.ldb, self.ldc, self.ldd, - self.lda, self.ldb, self.ldc, self.ldd, - 0, 0, 0 - ) - else: - arguments = self.operation.argument_type( - # Arguments from UniversalArgumentsBase - self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D, - # Remaining arguments - self.output_op, - int(self.ptr_A), - int(self.ptr_B), - int(self.ptr_C), - int(self.ptr_D), - self.batched_stride_A, - self.batched_stride_B, - self.batched_stride_C, - self.lda, self.ldb, self.ldc, self.ldd, - self.lda, self.ldb, self.ldc, self.ldd, - 0, 0, 0 - ) - - self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size - - def initialize(self): - launch_config = self.operation.rt_module.plan(self) - - # Get the host and device workspace - device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) - - if device_workspace_size > 0: - self.workspace_buffer = device_mem_alloc(device_workspace_size) - workspace_ptr = self.workspace_buffer.ptr - err, = cuda.cuMemsetD32( - workspace_ptr, 0, device_workspace_size // 4) - else: - workspace_ptr = None - - device_workspace = 0 - if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: - # In GEMM splik-K parallel, the D pointer is redirected to the workspace - self.ptr_D = cuda.CUdeviceptr(workspace_ptr) - elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: - device_workspace = workspace_ptr - - self.get_arguments() - - arguments, grid_tiled_shape, gemm_k_size = self.arguments - res_arg = self.operation.rt_module.get_args( - ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace))) - host_workspace = bytearray(res_arg.contents) - - device_workspace = None - - self.host_workspace = host_workspace - self.device_workspace = device_workspace - self.launch_config = launch_config - - def sync(self, stream_sync=True): - super().sync(stream_sync) - if hasattr(self.output_op, "sync"): - self.output_op.sync() - - -class GemmArguments2xStreamK(GemmArguments2x): - """ - Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and - user-provide tensors into the kernel's argument - - :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | - :class:`cutlass_cppgen.backend.GemmOperationGrouped` - - :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass_cppgen.shape.GemmCoord` - - :param A: tensor A - :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param B: tensor B - :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param C: tensor C - :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param D: tensor D - :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` - - :param output_op: output operator, optional - :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` - """ - - def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): - if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]: - raise Exception(f"Unsupported GEMM mode {gemm_mode}.") - - super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) - - def get_arguments(self): - batch_stride_A = self.problem_size.m * self.problem_size.k - batch_stride_B = self.problem_size.k * self.problem_size.n - batch_stride_C = self.problem_size.m * self.problem_size.n - batch_stride_D = self.problem_size.m * self.problem_size.n - - arguments = self.operation.argument_type( - self.gemm_mode, - GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k), - self.batch_count, - self.output_op, - int(self.ptr_A), - int(self.ptr_B), - int(self.ptr_C), - int(self.ptr_D), - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_stride_D, - self.lda, self.ldb, self.ldc, self.ldd, # strides - self.lda, self.ldb, self.ldc, self.ldd, - -1, # avail_sms - ) - return arguments - - def initialize(self): - # Get the host and device workspace - device_workspace_size = self.operation.rt_module.get_device_workspace_size( - self, - device_sm_count(), - self.operation.rt_module.occupancy - ) - - if device_workspace_size > 0: - self.workspace_buffer = device_mem_alloc(device_workspace_size) - workspace_ptr = self.workspace_buffer.ptr - err, = cuda.cuMemsetD32( - workspace_ptr, 0, device_workspace_size // 4) - else: - workspace_ptr = None - - device_workspace = 0 - if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: - # In GEMM splik-K parallel, the D pointer is redirected to the workspace - self.ptr_D = cuda.CUdeviceptr(workspace_ptr) - elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: - device_workspace = workspace_ptr - - arguments = self.get_arguments() - - res_arg = self.operation.rt_module.get_args( - ctypes.byref(arguments), - ctypes.c_void_p(int(device_workspace)), - device_sm_count(), - self.operation.rt_module.occupancy - ) - host_workspace = bytearray(res_arg.contents) - - grid = self.operation.rt_module.get_grid_shape( - ctypes.byref(arguments), - device_sm_count(), - self.operation.rt_module.occupancy - ) - - device_workspace = None - - self.host_workspace = host_workspace - self.device_workspace = device_workspace - self.launch_config = LaunchConfiguration( - [grid.m, grid.n, grid.k], - [self.operation.rt_module.threads, 1, 1], - self.operation.rt_module.shared_memory_capacity - ) - - -class GemmArguments3x(GemmArguments2x): - """ - Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and - user-provide tensors into the kernel's argument - - :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | - :class:`cutlass_cppgen.backend.GemmOperationGrouped` - - :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass_cppgen.shape.GemmCoord` - - :param A: tensor A - :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param B: tensor B - :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param C: tensor C - :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param D: tensor D - :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param gemm_mode: GEMM mode - :type gemm_mode: GemmUniversalMode - - :param output_op: output operator, optional - :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` - """ - - def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): - if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]: - raise Exception(f"Unsupported GEMM mode {gemm_mode}.") - - super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) - - def get_arguments(self): - mainloop_args = get_mainloop_arguments_3x( - self.operation.tile_description.kernel_schedule, - self.operation.A.element, - self.operation.B.element, - self.operation.A.alignment, - self.operation.B.alignment - ) - scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler) - uses_default_epilogue = self.operation.rt_module.uses_default_epilogue() - argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x( - mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue) - - problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count) - - if self.batch_count > 1: - bsA = self.batched_stride_A - bsB = self.batched_stride_B - bsC = self.batched_stride_C - bsD = self.batched_stride_D - else: - bsA = 0 - bsB = 0 - bsC = 0 - bsD = 0 - stride_A = StrideBatched_(self.lda, bsA) - stride_B = StrideBatched_(self.ldb, bsB) - stride_C = StrideBatched_(self.ldc, bsC) - stride_D = StrideBatched_(self.ldd, bsD) - - # Superset of potential mainloop arguments - generic_args = GenericMainloopArguments3x_( - int(self.ptr_A), - stride_A, - int(self.ptr_B), - stride_B, - 4 # mma_promotion_interval - ) - - # Set of mainloop arguments needed for this kernel - mainloop = mainloop_args.from_generic_mainloop_args(generic_args) - - if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"): - self.output_op = self.output_op.to_evt_params() - - epilogue = epilogue_args( - self.output_op, - int(self.ptr_C), - stride_C, - int(self.ptr_D), - stride_D, - ) - - # Set hardware info - hw_info_ = hw_info( - 0, device_sm_count(), 0, - dim3_(0,0,0), - dim3_(0,0,0), - ) - - self.arguments = argument_type( - int(self.gemm_mode), - problem_size_, - mainloop, - epilogue, - hw_info_, - scheduler_args - ) - return self.arguments - - def initialize(self): - # Get the host and evice workspace - device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) - - if device_workspace_size > 0: - self.workspace_buffer = device_mem_alloc(device_workspace_size) - workspace_ptr = self.workspace_buffer.ptr - err, = cuda.cuMemsetD32( - workspace_ptr, 0, device_workspace_size // 4) - else: - workspace_ptr = None - - device_workspace = 0 - if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: - # In GEMM splik-K parallel, the D pointer is redirected to the workspace - self.ptr_D = cuda.CUdeviceptr(workspace_ptr) - elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: - device_workspace = workspace_ptr - - self.get_arguments() - res_arg = self.operation.rt_module.get_args( - ctypes.byref(self.arguments), - ctypes.c_void_p(int(device_workspace)), - ) - host_workspace = bytearray(res_arg.contents) - - grid = self.operation.rt_module.get_grid_shape( - ctypes.byref(self.arguments), - ctypes.c_void_p(int(device_workspace)), - ) - block = self.operation.rt_module.get_block_shape() - - device_workspace = None - - self.host_workspace = host_workspace - self.device_workspace = device_workspace - self.launch_config = LaunchConfiguration( - [grid.x, grid.y, grid.z], - [block.x, block.y, block.z], - self.operation.rt_module.shared_memory_capacity, - ) - - -def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): - """ - Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments - or 3x arguments depending on the `arch` field specified in `operation`. - - :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | - :class:`cutlass_cppgen.backend.GemmOperationGrouped` - - :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass_cppgen.shape.GemmCoord` - - :param A: tensor A - :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param B: tensor B - :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param C: tensor C - :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param D: tensor D - :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` - - :param output_op: output operator, optional - :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` - """ - if operation.swizzling_functor == SwizzlingFunctor.StreamK: - if operation.api == ApiVersion.v3x: - raise Exception("Stream K is currently only supported in CUTLASS 2.x") - ArgClass = GemmArguments2xStreamK - else: - ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x - return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) - - -class GemmGroupedArguments: - """ - Argument wrapper for GEMM Grouped. It encodes problem information and - user-provide tensors into the kernel's argument - - :param operation: the GEMM Grouped operation to take the argument - :type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped` - - :param problem_size: list of GEMM problem size gemm(M, N, K) - :type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`] - - :param A: list of tensor A - :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] - - :param B: list of tensor B - :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] - - :param C: list of tensor C - :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] - - :param D: list of tensor D - :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] - - :param output_op: output operator, optional - :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` - - :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) - :type stream: :class:`cuda.cuda.CUstream` - """ - - def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs): - # Get number of problems in the group - self.problem_count = len(problem_sizes) - - # Check the input arguments - assert len(A) == self.problem_count - assert len(B) == self.problem_count - assert len(C) == self.problem_count - assert len(D) == self.problem_count - - problem_size_host = [] - self.ptr_A_host = [] - self.ptr_B_host = [] - self.ptr_C_host = [] - self.ptr_D_host = [] - - lda_host = [] - ldb_host = [] - ldc_host = [] - ldd_host = [] - - self.partitions = 1 - - self.operation = operation - - # Get the threadblock - threadblock_shape = operation.tile_description.threadblock_shape - self.threadblock_shape = GemmCoord( - threadblock_shape[0], - threadblock_shape[1], - threadblock_shape[2], - ) - self.threadblock_swizzle = operation.swizzling_functor - - self.total_tiles = 0 - - self.gemm_arguments = [] - - self.stream = kwargs.get("stream", cuda.CUstream(0)) - - # Process the input arguments - for idx, problem_size in enumerate(problem_sizes): - M, N, K = problem_size.m, problem_size.n, problem_size.k - temp_argument = GemmArguments2x( - operation=operation, - problem_size=GemmCoord(M, N, K), - A=A[idx], B=B[idx], C=C[idx], D=D[idx]) - self.gemm_arguments.append(temp_argument) - - problem_size_host.append( - [temp_argument.problem_size.m, - temp_argument.problem_size.n, - temp_argument.problem_size.k] - ) - - self.ptr_A_host.append(int(temp_argument.ptr_A)) - lda_host.append(temp_argument.lda) - - self.ptr_B_host.append(int(temp_argument.ptr_B)) - ldb_host.append(temp_argument.ldb) - - self.ptr_C_host.append(int(temp_argument.ptr_C)) - ldc_host.append(temp_argument.ldc) - - self.ptr_D_host.append(int(temp_argument.ptr_D)) - ldd_host.append(temp_argument.ldd) - - # Get number of tiles - grid = self.operation.rt_module.get_grid_shape( - self.operation.rt_module.get_tiled_shape( - temp_argument.problem_size.ctype, - self.threadblock_shape.ctype, - temp_argument.batch_count - ) - ) - self.total_tiles += grid.x * grid.y * grid.z - - self.problem_size_buffer = todevice(problem_size_host, np.int32) - self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64) - self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64) - self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64) - self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64) - - self.lda_buffer = todevice(lda_host, np.int64) - self.ldb_buffer = todevice(ldb_host, np.int64) - self.ldc_buffer = todevice(ldc_host, np.int64) - self.ldd_buffer = todevice(ldd_host, np.int64) - - if "output_op" in kwargs.keys(): - self.alpha = kwargs["output_op"].alpha - self.beta = kwargs["output_op"].beta - else: - self.alpha = 1.0 - self.beta = 0.0 - - if "output_op" in kwargs.keys(): - self.output_op = kwargs["output_op"] - else: - self.output_op = self.operation.epilogue_type(1.0, 0.0) - - # Get host problem size - self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0] - - self.arguments = self.get_arguments() - - self.initialize() - - def get_arguments(self): - return self.operation.argument_type( - self.problem_size_buffer.ptr, - self.problem_count, - self.total_tiles, - self.output_op, - self.ptr_A_buffer.ptr, - self.ptr_B_buffer.ptr, - self.ptr_C_buffer.ptr, - self.ptr_D_buffer.ptr, - self.lda_buffer.ptr, - self.ldb_buffer.ptr, - self.ldc_buffer.ptr, - self.ldd_buffer.ptr, - ctypes.c_void_p(int(self.host_problem_size_ptr)), - ) - - def initialize(self): - # Get launch configuration - launch_config = self.operation.rt_module.plan(self) - - # Get the host and evice workspace - device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) - - if device_workspace_size > 0: - self.workspace_buffer = device_mem_alloc(device_workspace_size) - workspace_ptr = self.workspace_buffer.ptr - err, = cuda.cuMemsetD32( - workspace_ptr, 0, device_workspace_size // 4) - else: - workspace_ptr = None - - if self.operation.precompute_mode == SchedulerMode.Host: - device_workspace_ptr = self.operation.rt_module.host_precompute( - self, self.operation.rt_module.get_workspace_size(self),) - else: - device_workspace_ptr = 0 - - result = self.operation.rt_module.get_args( - ctypes.byref(self.arguments), - self.total_tiles, - ctypes.c_void_p(int(device_workspace_ptr)), - ) - host_workspace = bytearray(result.contents) - - device_workspace = None - - self.host_workspace = host_workspace - self.device_workspace = device_workspace - self.launch_config = launch_config - - def sync(self): - err, = cudart.cudaDeviceSynchronize() - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - for arg in self.gemm_arguments: - arg.sync(stream_sync=False) - - -################################################################################ -# Base class for GEMM runtime module -################################################################################ - - -class GemmRTbase(ExecutableOperation): - """ - GemmRT manages the CUTLASS runtime components - """ - - KernelTemplate = r""" -extern "C" -__global__ void -${operation_name}(${operation_name}${operation_suffix}::Params params) { - - // Dynamic shared memory base pointer - extern __shared__ int SharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - ${operation_name}${operation_suffix}::SharedStorage *shared_storage = - reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); - - ${operation_name}${operation_suffix}::invoke(params, *shared_storage); -} - """ - - def __init__(self, operation: "GemmOperation"): - super().__init__(operation) - - self.operation = operation - threadblock_shape = operation.tile_description.threadblock_shape - self.threadblock_shape = GemmCoord( - threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) - self.threadblock_swizzle = operation.swizzling_functor - - # Threads per threadblock - self.threads = operation.tile_description.num_threads - - def emit(self): - return self.emitter.emit(self.operation) - - def can_implement(self, configuration, arguments): - raise NotImplementedError() - - def get_host_workspace_size(self, arguments): - raise NotImplementedError() - - def get_device_workspace_size(self, arguments): - return 0 - - def initialize(self): - err, = cuda.cuFuncSetAttribute( - self.kernel, - attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - value=self.shared_memory_capacity) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError( - f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}" - ) - - -################################################################################ -# Runtime module for GEMM Universal -################################################################################ - - -class GemmRTUniversal(GemmRTbase): - """ - GemmRTUniversal manages the CUTLASS runtime components - """ - - HostTemplate = r""" -extern "C" { - // Get the size of params in bytes - int ${operation_name}_get_param_size(){ - return sizeof(${operation_name}${operation_suffix}::Params); - } - - // Get the size of dynamic shared memory in bytes - int ${operation_name}_shared_memory_size() { - return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); - } - - // Get the params as byte array - char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){ - ${operation_name}_base::Params* params; - params = new ${operation_name}_base::Params(*argument, - -1, // SM count. Only used for stream-K - -1 // Occupancy. Only used for stream-K - ); - - // Semaphore holds the pointer to the workspace in the Params struct - params->semaphore = workspace; - - char *bytes = ((char*)(params)); - char *output = new char[sizeof(${operation_name}_base::Params)]; - for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) - output[i] = bytes[i]; - - return output; - } - - cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape( - cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) { - return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape( - problem_size, tile_size, split_k_slices); - } - - dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) { - return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape); - } -} - """ - - def __init__(self, operation): - super(GemmRTUniversal, self).__init__(operation) - self.extra_funcs = { - "get_tiled_shape": GemmCoord_, - "get_grid_shape": dim3_, - } - self.emitter = EmitGemmUniversalInstance( - "_type", operation.direct_store) - - self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor) - self.argtype = [ - ctypes.POINTER(self.argument_type), - ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p - ] - - def plan(self, arguments): - grid = self.get_tiled_shape( - arguments.problem_size.ctype, - self.threadblock_shape.ctype, - arguments.batch_count - ) - - gemm_k_size = arguments.problem_size.k - if arguments.gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]: - alignk = max(max(128 // DataTypeSize[self.operation.A.element], - 128 // DataTypeSize[self.operation.B.element]), 1) - - gemm_k_size = (((arguments.problem_size.k + arguments.batch_count - 1) // - arguments.batch_count + alignk - 1) // alignk) * alignk - - if gemm_k_size: - grid_z = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size - grid = GemmCoord(grid.m, grid.n, grid_z).ctype - - arguments.grid_tiled_shape = dim3_(grid.m, grid.n, grid.k) - grid = self.get_grid_shape(grid) - arguments.gemm_k_size = gemm_k_size - return LaunchConfiguration( - [grid.x, grid.y, grid.z], - [self.threads, 1, 1], - self.shared_memory_capacity) - - def get_device_workspace_size(self, arguments: GemmArguments): - workspace_bytes = 0 - if arguments.gemm_mode == GemmUniversalMode.GemmSplitKParallel: - workspace_bytes = (DataTypeSize[arguments.operation.C.element] - * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8) - elif (arguments.gemm_mode == GemmUniversalMode.Gemm and - arguments.split_k_slices > 1): - workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y - - return workspace_bytes - - -class GemmRTUniversalStreamK(GemmRTUniversal): - """ - Manages the CUTLASS runtime components for 2.x stream K kernels - """ - - HostTemplate = r""" -extern "C" { - // Get the size of params in bytes - int ${operation_name}_get_param_size(){ - return sizeof(${operation_name}${operation_suffix}::Params); - } - - // Get the size of dynamic shared memory in bytes - int ${operation_name}_shared_memory_size() { - return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); - } - - using GemmType = ${operation_name}_base; - - // Get the params as byte array - char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace, - int sm_count, int occupancy) { - GemmType::Params* params; - params = new GemmType::Params(*argument, sm_count, occupancy); - - params->init_workspace(workspace); - - char *bytes = ((char*)(params)); - char *output = new char[sizeof(GemmType::Params)]; - for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) - output[i] = bytes[i]; - - return output; - } - - dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) { - typename GemmType::Params params(*args, device_sms, sm_occupancy); - return params.get_grid_dims(); - } - - uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) { - typename GemmType::Params params(*args, device_sms, sm_occupancy); - return params.get_workspace_size(); - } -} - """ - - def __init__(self, operation: "GemmOperation"): - super(GemmRTUniversalStreamK, self).__init__(operation) - self.extra_funcs = { - "get_grid_shape": GemmCoord_, - "get_kernel_workspace_size": ctypes.c_uint64, - } - self._occupancy = None - self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor) - - @property - def occupancy(self): - if self._occupancy is None: - err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( - self.kernel, self.threads, self.shared_memory_capacity, - cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE) - - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError( - "CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: " - f"{cuda.cuGetErrorString(err)[1]}") - return self._occupancy - - def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int): - return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy) - - -################################################################################ -# Runtime module for GEMM Universal within CUTLASS 3 -################################################################################ - - -class GemmRTUniversal3x(GemmRTUniversal): - """ - Manages the CUTLASS runtime components for 3.x kernels - """ - - KernelTemplate = r""" - -using Operator = ${operation_name}${operation_suffix}; -extern "C" -__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) -void ${operation_name}(__grid_constant__ typename Operator::Params const params) { - // Dynamic shared memory base pointer - extern __shared__ char smem[]; - - // Declare pointer to dynamic shared memory. - Operator op; - op(params, smem); -} - """ - HostTemplate = r""" -extern "C" { - // Get the size of params in bytes - int ${operation_name}_get_param_size(){ - return sizeof(${operation_name}${operation_suffix}::Params); - } - - // Get the size of dynamic shared memory in bytes - int ${operation_name}_shared_memory_size() { - return ${operation_name}${operation_suffix}::SharedStorageSize; - } - - using GemmType = ${operation_name}_base; - - bool ${operation_name}_uses_default_epilogue() { - return std::is_same_v; - } - - // Get the workspace size - uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) { - return GemmType::get_workspace_size(*argument); - } - - // Get the params as byte array - char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){ - GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace); - char *bytes = ((char*)(¶ms)); - char *output = new char[sizeof(GemmType::Params)]; - for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) - output[i] = bytes[i]; - - return output; - } - - // Get the total number of blocks for a persistent kernel - uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) { - auto problem_shape_MNKL = append<4>(problem, Int<1>{}); - auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = - cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( - problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{}); - return problem_blocks_m * problem_blocks_n * problem_blocks_l; - } - - // Get the grid shape - dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) { - auto tmp_params = GemmType::to_underlying_arguments(*args, workspace); - return GemmType::get_grid_shape(tmp_params); - } - - // Get the block shape - dim3 ${operation_name}_get_block_shape() { - return GemmType::get_block_shape(); - } -} - """ - - def __init__(self, operation): - super(GemmRTUniversal3x, self).__init__(operation) - self.extra_funcs = { - "get_grid_shape": dim3_, - "get_block_shape": dim3_, - "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64, - "get_kernel_workspace_size": ctypes.c_uint64, - "uses_default_epilogue": ctypes.c_bool, - } - self.emitter = EmitGemmUniversalInstance3x("_type") - - def get_device_workspace_size(self, arguments: GemmArguments3x): - return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments())) - - -class EmitGemmUniversalInstance3x: - """Responsible for emitting a CUTLASS 3 template definition""" - - def __init__(self, operation_suffix=""): - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cute/tensor.hpp", - "cute/atom/mma_atom.hpp", - "cutlass/numeric_types.h", - "cutlass/gemm/collective/collective_builder.hpp", - "cutlass/gemm/kernel/sm90_tile_scheduler.hpp", - "cutlass/gemm/kernel/gemm_universal.hpp", - "cutlass/epilogue/collective/collective_builder.hpp", - "cutlass/epilogue/collective/default_epilogue.hpp", - "cutlass/epilogue/thread/linear_combination.h" - ] - self.gemm_template_kernel = """ -using namespace cute; - -using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ${arch}, ${opcode_class}, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - ${element_accumulator}, ${element_epilogue}, - ${element_c}, ${layout_c}, ${align_c}, - ${element_d}, ${layout_d}, ${align_d}, - ${epilogue_schedule} - >::CollectiveOp; - -using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ${arch}, ${opcode_class}, - ${element_a}, ${layout_a}, ${align_a}, - ${element_b}, ${layout_b}, ${align_b}, - ${element_accumulator}, - cute::Shape, - cute::Shape, - ${stage_count_type}, - ${kernel_schedule} - >::CollectiveOp; - -// Gemm operator ${operation_name} -using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - ${tile_scheduler} ->; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - self.gemm_template_kernel_visitor = """ -using namespace cute; - -${callback_decl} - -using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ${arch}, ${opcode_class}, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - ${element_accumulator}, ${element_epilogue}, - ElementC, StrideC, ${align_c}, - ElementD, StrideD, ${align_d}, - ${epilogue_schedule}, - ${callback_name} - >::CollectiveOp; - -using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ${arch}, ${opcode_class}, - ${element_a}, ${layout_a}, ${align_a}, - ${element_b}, ${layout_b}, ${align_b}, - ${element_accumulator}, - cute::Shape, - cute::Shape, - ${stage_count_type}, - ${kernel_schedule} - >::CollectiveOp; - -// Gemm operator ${operation_name} -using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - ${tile_scheduler} ->; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - - self.gemm_template_device = self.gemm_template_kernel + """ - -// Define device-level operator -using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>; -""" - - def emit(self, operation): - # Support built-in epilogue functors or user-defined functions - - if operation.tile_description.stages is None or operation.tile_description.stages == 0: - stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>" - else: - stage_count_type = "_" + str(operation.tile_description.stages) - - if operation.emission_type == EmissionType.Kernel: - gemm_template = self.gemm_template_kernel - else: - gemm_template = self.gemm_template_device - - kschedule = KernelScheduleType.ScheduleAuto - eschedule = EpilogueScheduleType.ScheduleAuto - tschedule = TileSchedulerType.Default - if operation.tile_description.kernel_schedule is not None: - kschedule = operation.tile_description.kernel_schedule - if operation.tile_description.epilogue_schedule is not None: - eschedule = operation.tile_description.epilogue_schedule - if operation.tile_description.tile_scheduler is not None: - tschedule = operation.tile_description.tile_scheduler - - emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape - - values = { - "operation_name": operation.procedural_name(), - "operation_suffix": self.operation_suffix, - "element_a": DataTypeTag[operation.A.element], - "layout_a": LayoutTag[operation.A.layout], - "element_b": DataTypeTag[operation.B.element], - "layout_b": LayoutTag[operation.B.layout], - "element_c": DataTypeTag[operation.C.element], - "layout_c": LayoutTag[operation.C.layout], - "element_d": DataTypeTag[operation.epilogue_functor.element_output], - "layout_d": LayoutTag[operation.C.layout], - "element_accumulator": DataTypeTag[operation.accumulator_type()], - "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue], - "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - "arch": "cutlass::arch::Sm%d" % operation.arch, - "threadblock_shape_m": str(emit_tile_m), - "threadblock_shape_n": str(emit_tile_n), - "threadblock_shape_k": str(emit_tile_k), - "cluster_m": str(operation.tile_description.cluster_shape[0]), - "cluster_n": str(operation.tile_description.cluster_shape[1]), - "cluster_k": str(operation.tile_description.cluster_shape[2]), - "align_a": str(operation.A.alignment), - "align_b": str(operation.B.alignment), - "align_c": str(operation.C.alignment), - "align_d": str(operation.C.alignment), - "stage_count_type": stage_count_type, - "kernel_schedule": KernelScheduleTag[kschedule], - "epilogue_schedule": EpilogueScheduleTag[eschedule], - "tile_scheduler": TileSchedulerTag[tschedule] - } - if hasattr(operation.epilogue_functor, "visitor"): - callback_name, callback_decl = operation.epilogue_functor.emit(operation) - values["callback_name"] = callback_name - values["callback_decl"] = callback_decl - return SubstituteTemplate(self.gemm_template_kernel_visitor, values) - - else: - values["epilogue_functor"] = operation.epilogue_functor.emit() - return SubstituteTemplate(gemm_template, values) - - -################################################################################################### -# Runtime module for GEMM Grouped -################################################################################################### - - -class GemmRTGrouped(GemmRTbase): - """ - GemmRTGrouped manages the CUTLASS runtime components - """ - - KernelTemplate = r""" -extern "C" -__global__ void -${operation_name}(${operation_name}${operation_suffix}::Params params) { - - // Dynamic shared memory base pointer - extern __shared__ int SharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - ${operation_name}${operation_suffix}::SharedStorage *shared_storage = - reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); - - ${operation_name}${operation_suffix} op; - - op(params, *shared_storage); -} - """ - - HostTemplate = r""" - extern "C" { - - // precompute scheduling information - char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) { - char* host_workspace = new char[workspace_bytes]; - ${operation_name}_base::ProblemVisitor::host_precompute( - args.host_problem_sizes, - args.problem_count, - args.threadblock_count, - (void*)host_workspace - ); - return host_workspace; - } - - // Get the size of params in bytes - int ${operation_name}_get_param_size(){ - return sizeof(${operation_name}${operation_suffix}::Params); - } - - // Get the size of dynamic shared memory in bytes - int ${operation_name}_shared_memory_size() { - return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); - } - - // Get the params as byte array - char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){ - ${operation_name}_base::Params* params; - params = new ${operation_name}_base::Params(*argument, workspace, tile_count); - - char *bytes = ((char*)(params)); - char *output = new char[sizeof(${operation_name}_base::Params)]; - for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) - output[i] = bytes[i]; - - return output; - } - - cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape( - cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) { - return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape( - problem_size, tile_size, split_k_slices); - } - - dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) { - return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape); - } - } - """ - - def __init__(self, operation: "GemmOperation"): - super(GemmRTGrouped, self).__init__(operation) - self.extra_funcs = { - "precompute": None, - "get_tiled_shape": GemmCoord_, - "get_grid_shape": dim3_, - } - self.emitter = EmitGemmGroupedInstance("_type") - self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor) - self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p] - - def host_precompute(self, arguments, workspace_bytes): - self.precompute.argtype = [ - self.argtype[0], ctypes.c_int, ctypes.c_longlong] - self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes) - - problem_info = self.precompute( - ctypes.byref(arguments.arguments), - arguments.total_tiles, - workspace_bytes) - problem_info_array = bytearray(problem_info.contents) - - # copy to device memory - return todevice(problem_info_array).ptr - - def plan(self, arguments): - return LaunchConfiguration( - [arguments.total_tiles, 1, 1], - [self.threads, 1, 1], - self.shared_memory_capacity, - ) - - def get_workspace_size(self, arguments): - if self.operation.precompute_mode == SchedulerMode.Device: - return 0 - elif self.operation.precompute_mode == SchedulerMode.Host: - total_tiles = arguments.total_tiles - entries_per_block = 1 - return 8 * entries_per_block * total_tiles # three int32_t - - -################################################################################ -# Runtime module for GEMM and grouped GEMM -################################################################################ - - -class GemmOperationBase: - """ - CUTLASS GEMM operation - """ - - def __init__( - self, gemm_kind, arch, tile_description: TileDescription, - A: TensorDescription, B: TensorDescription, C: TensorDescription, - epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, - api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs): - self.operation_kind: OperationKind = OperationKind.Gemm - self.arch: int = arch - self.tile_description: TileDescription = tile_description - self.gemm_kind: GemmKind = gemm_kind - - self.api = api - self.prefix = "3x" if self.api == ApiVersion.v3x else "" - self.emission_type = emission_type - - # Optionally swap the TensorDescriptions for operands A and B and transpose their - # layouts. This is needed to mimic the transpose performed by device::GemmUniversal. - # The code below uses deep copy to avoid overwritting the original TensorDescription - self.switched = (self.api != ApiVersion.v3x and - self.emission_type == EmissionType.Kernel and - C.layout == LayoutType.ColumnMajor) - - self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched) - - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - - if "direct_store" in kwargs: - self.direct_store = kwargs["direct_store"] - else: - self.direct_store = False - - @staticmethod - def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool): - """ - Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set, - A and B are swapped, and the layout of A, B, and C are transposed. - - :param A: description of operand A - :type A: TensorDescription - :param B: description of operand B - :type B: TensorDescription - :param C: description of operand C - :type C: TensorDescription - - :return: descriptions of operands A, B, and C - :rtype: tuple[TileDescription] - """ - if swap: - A_out = copy.deepcopy(B) - B_out = copy.deepcopy(A) - C_out = copy.deepcopy(C) - A_out.layout = transpose_layout(A_out.layout) - B_out.layout = transpose_layout(B_out.layout) - C_out.layout = transpose_layout(C_out.layout) - else: - A_out = copy.deepcopy(A) - B_out = copy.deepcopy(B) - C_out = copy.deepcopy(C) - return A_out, B_out, C_out - - def run(self, arguments: GemmArguments) -> cuda.CUresult: - """ - Configure and launch the cuda kernel with input arguments - """ - if self.emission_type == EmissionType.Device: - raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"') - - err = self.rt_module.run( - arguments.host_workspace, - arguments.device_workspace, - arguments.launch_config, - arguments.stream - ) - - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - return err - - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32, - ] - return self.tile_description.math_instruction.math_operation in complex_operators - - def is_planar_complex(self): - return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) - - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - def core_name(self): - """The basic operation kind is prefixed with a letter indicating the accumulation type.""" - - inst_shape = "" - inst_operation = "" - intermediate_type = "" - - math_operations_map = { - MathOperation.xor_popc: "xor", - } - - if (self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp): - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else "" - - if self.tile_description.math_instruction.instruction_shape is not None: - if self.api == ApiVersion.v3x and self.arch >= 90: - inst_shape = "%dx%dx%d" % tuple( - self.tile_description.math_instruction.instruction_shape) - else: - inst_shape = "%d%d%d" % tuple( - self.tile_description.math_instruction.instruction_shape) - else: - inst_shape = "Default" - inst_shape += math_op_string - - if (self.tile_description.math_instruction.element_a != self.A.element and - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator): - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) - - def extended_name(self): - """Append data types if they differ from compute type.""" - if self.is_complex(): - extended_name = "${core_name}" - else: - if (self.C.element != self.tile_description.math_instruction.element_accumulator and - self.A.element != self.tile_description.math_instruction.element_accumulator): - extended_name = "${element_c}_${core_name}_${element_a}" - elif (self.C.element == self.tile_description.math_instruction.element_accumulator and - self.A.element != self.tile_description.math_instruction.element_accumulator): - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - "element_a": DataTypeNames[self.A.element], - "element_c": DataTypeNames[self.C.element], - "core_name": self.core_name(), - }) - - return extended_name - - def extended_name_3x(self): - """Generates a string representing the MMA atom. Assumes accumulator type is C type.""" - extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( - element_a=DataTypeNames[self.A.element], - element_b=DataTypeNames[self.B.element], - element_acc=DataTypeNames[self.accumulator_type()], - element_c=DataTypeNames[self.C.element], - element_d=DataTypeNames[self.epilogue_functor.element_output], - core_name=self.core_name()) - return extended_name - - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] - ) - return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) - - # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) - def layout_name_3x(self): - if self.is_complex() or self.is_planar_complex(): - return "{}{}{}".format( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], - ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) - else: - return "{}{}{}".format( - ShortLayoutTypeNames[self.A.layout], - ShortLayoutTypeNames[self.B.layout], - ShortLayoutTypeNames[self.C.layout]) - - # Generates a short string representing underlying kernel schedule type - def kernel_schedule_name_3x(self): - if self.tile_description.kernel_schedule is None: - return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto] - else: - return KernelScheduleSuffixes[self.tile_description.kernel_schedule] - - # Generates a short string representing underlying epilogue schedule type - def epilogue_schedule_name_3x(self): - if self.tile_description.epilogue_schedule is None: - return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto] - else: - return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule] - - def procedural_name(self): - """The full procedural name indicates architecture, extended name, tile size, and layout.""" - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - if self.api == ApiVersion.v3x and self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" - return kernel_name_template.format( - p=self.prefix, - ar=self.arch, - op=opcode_class_name, - ex=self.extended_name_3x(), - tbm=self.tile_description.threadblock_shape[0], - tbn=self.tile_description.threadblock_shape[1], - tbk=self.tile_description.threadblock_shape[2], - cm=self.tile_description.cluster_shape[0], - cn=self.tile_description.cluster_shape[1], - ck=self.tile_description.cluster_shape[2], - l=self.tile_description.stages, - s=self.layout_name_3x(), - al=str(self.A.alignment), - k=self.kernel_schedule_name_3x(), - e=self.epilogue_schedule_name_3x() - ) - else: - threadblock = self.tile_description.procedural_name_2x() - return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( - p=self.prefix, - op=opcode_class_name, - ex=self.extended_name(), - tb=threadblock, - l=self.layout_name(), - a=str(self.A.alignment) - ) - - def configuration_name(self): - """The full procedural name indicates architecture, extended name, tile size, and layout.""" - return self.procedural_name() - - -class GemmOperationUniversal(GemmOperationBase): - def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, - epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs): - api = api_version(arch, tile_description.math_instruction.opcode_class, A.element) - super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, - A, B, C, epilogue_functor, swizzling_functor, - api=api, **kwargs, ) - if api == ApiVersion.v3x: - if swizzling_functor == SwizzlingFunctor.StreamK: - raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels") - self.rt_module = GemmRTUniversal3x(self) - else: - if swizzling_functor == SwizzlingFunctor.StreamK: - self.rt_module = GemmRTUniversalStreamK(self) - else: - self.rt_module = GemmRTUniversal(self) - self.argument_type = self.rt_module.argument_type - self.epilogue_type = self.rt_module.epilogue_type - - def device_op(self): - """ - Returns a new GemmOperationUniversal object that is constructed with emission type - ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, - any swappng performed by the kernel-emitted operation is reversed. - - :return: operation ready for device-level code emission - :rtype: GemmUniversalOperation - """ - A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) - return GemmOperationUniversal(self.arch, self.tile_description, A, B, C, - self.epilogue_functor, self.swizzling_functor, - emission_type=EmissionType.Device, direct_store=self.direct_store) - - -class GemmOperationGrouped(GemmOperationBase): - def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, - epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs): - super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description, - A, B, C, epilogue_functor, swizzling_functor, **kwargs) - assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'." - self.precompute_mode = kwargs["precompute_mode"] - self.rt_module = GemmRTGrouped(self) - self.argument_type = self.rt_module.argument_type - self.epilogue_type = self.rt_module.epilogue_type - - def device_op(self): - """ - Returns a new GemmOperationGrouped object that is constructed with emission type - ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, - any swappng performed by the kernel-emitted operation is reversed. - - :return: operation ready for device-level code emission - :rtype: GemmOperationGrouped - """ - A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) - return GemmOperationGrouped( - self.arch, self.tile_description, A, B, C, self.epilogue_functor, - self.swizzling_functor, emission_type=EmissionType.Device, - direct_store=self.direct_store, precompute_mode=self.precompute_mode, ) - - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - - -class EmitGemmUniversalInstance: - """Responsible for emitting a CUTLASS template definition""" - - def __init__( - self, - operation_suffix="", - direct_store=False - ): - self.operation_suffix = operation_suffix - self.direct_store = direct_store - self.includes = [ - "cutlass/cutlass.h", - "cutlass/gemm_coord.h", - "cutlass/numeric_types.h", - "cutlass/arch/arch.h", - "cutlass/arch/mma.h", - "cutlass/layout/matrix.h", - "cutlass/gemm/device/gemm.h", - "cutlass/gemm/device/gemm_universal_adapter.h", - "cutlass/gemm/kernel/default_gemm_universal.h", - ] - if self.direct_store: - self.includes.append( - "cutlass/epilogue/threadblock/default_epilogue_direct_store.h" - ) - self.gemm_template_kernel = """ -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operation} ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - - self.gemm_template_device = """ -// Gemm operator ${operation_name} -using DeviceKernel = - typename cutlass::gemm::device::GemmUniversal< - // Data type and layout of operand A - ${element_a}, ${layout_a}, - // Data type and layout of operand B - ${element_b}, ${layout_b}, - // Data type and layout of operand C - ${element_c}, ${layout_c}, - // Data type of accumulator - ${element_accumulator}, - // Class of operation - ${opcode_class}, - // Compute capability of the target kernel - ${arch}, - // Threadblock tile shape - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - // Warp tile shape - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - // Instruction shape - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - // Epilogue functor - ${epilogue_functor}, - // Swizzling function - ${swizzling_functor}, - // Number of pipeline stages - ${stages}, - // Alignment of operands A and B - ${align_a}, ${align_b}, - // Type of math operation - ${math_operation}, - // Complex transform types of operands A and B - ${transform_a}, ${transform_b} - >; -""" - self.gemm_template_direct_store = """ -// Gemm operator ${operation_name} -using ${operation_name}_default = - typename cutlass::gemm::kernel::DefaultGemmUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operation} ->::GemmKernel; - -using ${operation_name}_base = - cutlass::gemm::kernel::GemmUniversal< - ${operation_name}_default::Mma, - cutlass::epilogue::threadblock::DefaultEpilogueDirectStore< - ${operation_name}_default::Epilogue - >::Epilogue, - ${operation_name}_default::ThreadblockSwizzle - >; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - self.gemm_template_kernel_visitor = """ - -using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - ${element_c}, - ${align_c}, - ${epilogue_stages} /* epilogue stages */ ->; - -${callback_decl} - -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, ${align_c}, - ${element_accumulator}, - ${element_epilogue}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${callback_name}, - ${swizzling_functor}, - ${stages}, - ${math_operation}, - ${epilogue_stages} /* epilogue stages */ ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - def emit(self, operation): - threadblock_shape = operation.tile_description.threadblock_shape - warp_count = operation.tile_description.warp_count - - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - instance_layout_A, instance_layout_B, instance_layout_C = \ - (operation.A.layout, operation.B.layout, operation.C.layout) - - if operation.emission_type == EmissionType.Kernel: - if self.direct_store: - gemm_template = self.gemm_template_direct_store - else: - gemm_template = self.gemm_template_kernel - else: - gemm_template = self.gemm_template_device - - values = { - "operation_name": operation.procedural_name(), - "operation_suffix": self.operation_suffix, - "element_a": DataTypeTag[operation.A.element], - "layout_a": LayoutTag[instance_layout_A], - "element_b": DataTypeTag[operation.B.element], - "layout_b": LayoutTag[instance_layout_B], - "element_c": DataTypeTag[operation.C.element], - "layout_c": LayoutTag[instance_layout_C], - "element_accumulator": DataTypeTag[operation.accumulator_type()], - "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - "arch": "cutlass::arch::Sm%d" % operation.arch, - "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), - "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), - "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), - "warp_shape_m": str(warp_shape[0]), - "warp_shape_n": str(warp_shape[1]), - "warp_shape_k": str(warp_shape[2]), - "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), - "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), - "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), - "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], - "stages": str(operation.tile_description.stages), - "align_a": str(operation.A.alignment), - "align_b": str(operation.B.alignment), - "transform_a": ComplexTransformTag[operation.A.complex_transform], - "transform_b": ComplexTransformTag[operation.B.complex_transform], - "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], - } - - if hasattr(operation.epilogue_functor, "visitor"): - self.includes += [ - "cutlass/epilogue/threadblock/fusion/visitors.hpp", - "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" - ] - callback_name, callback_decl = operation.epilogue_functor.emit(operation) - values["callback_name"] = callback_name - values["callback_decl"] = callback_decl - values["align_c"] = str(operation.C.alignment) - values["element_epilogue"] = DataTypeTag[operation.epilogue_functor.element_epilogue] - if hasattr(operation.epilogue_functor, "epilogue_stages"): - epilogue_stages = operation.epilogue_functor.epilogue_stages - else: - epilogue_stages = 1 - values["epilogue_stages"] = str(epilogue_stages) - return SubstituteTemplate(self.gemm_template_kernel_visitor, values) - else: - values["epilogue_functor"] = operation.epilogue_functor.emit() - return SubstituteTemplate(gemm_template, values) - - -class EmitGemmGroupedInstance: - """Responsible for emitting a CUTLASS template definition""" - - def __init__(self, operation_suffix=""): - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/numeric_types.h", - "cutlass/arch/arch.h", - "cutlass/arch/mma.h", - "cutlass/layout/matrix.h", - "cutlass/gemm/kernel/gemm_grouped.h", - "cutlass/gemm/kernel/default_gemm_grouped.h", - ] - self.gemm_template_kernel = """ -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmGrouped< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${precompute_mode}, - ${math_operation} ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - self.gemm_template_device = ( - self.gemm_template_kernel - + """ -using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>; -""" - ) - - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmGrouped<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - def emit(self, operation): - threadblock_shape = operation.tile_description.threadblock_shape - warp_count = operation.tile_description.warp_count - - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - instance_layout_A, instance_layout_B, instance_layout_C = \ - (operation.A.layout, operation.B.layout, operation.C.layout) - - # Support built-in epilogue functors or user-defined functions - epilogue_functor = operation.epilogue_functor.emit() - - values = { - "operation_name": operation.procedural_name(), - "operation_suffix": self.operation_suffix, - "element_a": DataTypeTag[operation.A.element], - "layout_a": LayoutTag[instance_layout_A], - "element_b": DataTypeTag[operation.B.element], - "layout_b": LayoutTag[instance_layout_B], - "element_c": DataTypeTag[operation.C.element], - "layout_c": LayoutTag[instance_layout_C], - "element_accumulator": DataTypeTag[operation.accumulator_type()], - "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - "arch": "cutlass::arch::Sm%d" % operation.arch, - "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), - "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), - "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), - "warp_shape_m": str(warp_shape[0]), - "warp_shape_n": str(warp_shape[1]), - "warp_shape_k": str(warp_shape[2]), - "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), - "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), - "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), - "epilogue_functor": epilogue_functor, - "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], - "stages": str(operation.tile_description.stages), - "align_a": str(operation.A.alignment), - "align_b": str(operation.B.alignment), - "transform_a": ComplexTransformTag[operation.A.complex_transform], - "transform_b": ComplexTransformTag[operation.B.complex_transform], - "precompute_mode": SchedulerModeTag[operation.precompute_mode], - "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], - } - - if operation.emission_type == EmissionType.Kernel: - gemm_template = self.gemm_template_kernel - else: - gemm_template = self.gemm_template_device - - return SubstituteTemplate(gemm_template, values) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py deleted file mode 100644 index a77b302dcccf330cc0e0f9b3f1290ab7030c5932..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py +++ /dev/null @@ -1,509 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Common data types and string names/tags for them -""" - -import enum - -from cutlass_library import ( - ComplexTransform, - DataType, - DataTypeSize, - EpilogueScheduleType, - KernelScheduleSuffixes, - KernelScheduleType, - MathOperation, - OpcodeClass, - TileSchedulerType -) - - -# The following block implements enum.auto() for Python 3.5 variants that don't include it such -# as the default 3.5.2 on Ubuntu 16.04. -# -# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility - -try: - from enum import auto as enum_auto -except ImportError: - __cutlass_library_auto_enum = 0 - - def enum_auto() -> int: - global __cutlass_library_auto_enum - i = __cutlass_library_auto_enum - __cutlass_library_auto_enum += 1 - return i - - -class DataTypeSizeBytes: - """ - Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the - data type key is less than a full byte or a non-integer number of bytes. - """ - - @staticmethod - def __class_getitem__(datatype): - """ - Returns the number of bytes in size the data type is. Raises an exception if the data type - is either less than a full byte or a non-integer number of bytes in size. - - :param datatype: data type to query - - :return: number of bytes the data type occupies - :rtype: int - """ - bits = DataTypeSize[datatype] - if bits < 8: - raise Exception( - f"Data type {datatype} is less than one byte in size." - ) - elif bits % 8 != 0: - raise Exception( - f"Data type datatype is not an integer number of bytes." - ) - return bits // 8 - - -class SchedulerMode(enum.Enum): - Device = enum_auto() - Host = enum_auto() - - -SchedulerModeTag = { - SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", - SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute", -} - - -ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"} - - -class FunctionalOp(enum.Enum): - AtomicAdd = enum_auto() - AtomicMaximum = enum_auto() - Divides = enum_auto() - Maximum = enum_auto() - Minimum = enum_auto() - Minus = enum_auto() - Multiplies = enum_auto() - MultiplyAdd = enum_auto() - Plus = enum_auto() - Exp = enum_auto() - - -FunctionalOpTag = { - FunctionalOp.AtomicAdd: "cutlass::atomic_add", - FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum", - FunctionalOp.Divides: "cutlass::divides", - FunctionalOp.Maximum: "cutlass::maximum", - FunctionalOp.Minimum: "cutlass::minimum", - FunctionalOp.Minus: "cutlass::minus", - FunctionalOp.Multiplies: "cutlass::multiplies", - FunctionalOp.MultiplyAdd: "cutlass::multiply_add", - FunctionalOp.Plus: "cutlass::plus", - FunctionalOp.Exp: "cutlass::fast_exp_op", -} - - -class ActivationOp(enum.Enum): - DGelu = enum_auto() - Gelu = enum_auto() - GeluTaylor = enum_auto() - HardSwish = enum_auto() - Identity = enum_auto() - LeakyReLU = enum_auto() - ReLU = enum_auto() - Sigmoid = enum_auto() - SiLU = enum_auto() - Tanh = enum_auto() - - -ActivationOpTag = { - ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU", - ActivationOp.Gelu: "cutlass::epilogue::thread::GELU", - ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor", - ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish", - ActivationOp.Identity: "cutlass::epilogue::thread::Identity", - ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU", - ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu", - ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid", - ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu", - ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh", -} - - -def op_tag(op) -> str: - """ - Dispatches `op` to the appropriate *Tag dictionary depending on whether - `op` is an ActivationOp or FunctionalOp. This is useful for cases in which - either type can be used. - - :param op: operation to emit a tag for - :type op: ActivationOp | FunctionalOp - - :return: tag corresponding to op - :rtype: str - """ - if isinstance(op, ActivationOp): - return ActivationOpTag[op] - elif isinstance(op, FunctionalOp): - return FunctionalOpTag[op] - else: - raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.") - - -class FloatRoundStyle(enum.Enum): - ToNearest = enum_auto() - ToNearestSatfinite = enum_auto() - Indeterminate = enum_auto() - TowardZero = enum_auto() - TowardInfinity = enum_auto() - TowardNegInfinity = enum_auto() - HalfUlpTruncDntz = enum_auto() - HalfUlpTruncate = enum_auto() - - -FloatRoundStyleTag = { - FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest", - FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite", - FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate", - FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero", - FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity", - FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity", - FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz", - FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate", -} - - -class MathInstruction: - """ - Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel - """ - - def __init__( - self, - instruction_shape, - element_a, - element_b, - element_accumulator, - opcode_class=OpcodeClass.Simt, - math_operation=MathOperation.multiply_add, - ): - """ - :param instruction_shape: size of the [M, N, K] dimensions of the instruction - :type instruction_shape: list or tuple - :param element_a: data type of operand A - :param element_b: data type of operand B - :param element_accumulator: data type used in accumulation - :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core) - :type opcode_class: cutlass_library.library.OpcodeClass - :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate) - :type math_operation: MathOperation - """ - self.instruction_shape = instruction_shape - self.element_a = element_a - self.element_b = element_b - self.element_accumulator = element_accumulator - self.opcode_class = opcode_class - self.math_operation = math_operation - - -def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule): - blackwell_threadblock_shape = tile_description.threadblock_shape - is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule]) - if cluster_shape[0] > 0: - blackwell_threadblock_shape = [ - tile_description.threadblock_shape[0] // cluster_shape[0], - tile_description.threadblock_shape[1] // cluster_shape[1], - tile_description.threadblock_shape[2] // cluster_shape[2] - ] - if is_2sm: - blackwell_threadblock_shape[0] *= 2 - else: - blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape - return blackwell_threadblock_shape, is_2sm - - -class TileDescription: - """ - Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, - stage count, and math instruction specification - """ - - def __init__( - self, - threadblock_shape, - stages, - warp_count, - math_instruction, - cluster_shape=[1, 1, 1], - kernel_schedule: KernelScheduleType = None, - epilogue_schedule: EpilogueScheduleType = None, - tile_scheduler: TileSchedulerType = None - ): - """ - :param threadblock_shape: shape of a threadblock tyle - :type threadblock_shape: list or tuple - :param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum - number of stages that can be supported for an operation on a given architecture will be computed at a later time - :type stages: int or None - :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile - :type warp_count: list, tuple, or None - :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands - :type math_instruction: MathInstruction - :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster - :param kernel_schedule: type of kernel schedule to use (only available for SM90+) - :type kernel_schedule: cutlass_library.KernelScheduleType - :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+) - :type epilogue_schedule: cutlass_library.EpilogueScheduleType - :param tile_scheduler: type of tile scheduler to use (only available for SM90+) - :type tile_scheduler: cutlass_library.TileSchedulerType - """ - if ((kernel_schedule is None and epilogue_schedule is not None) or - (kernel_schedule is not None and epilogue_schedule is None)): - raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.") - - self.threadblock_shape = threadblock_shape - self.cluster_shape = cluster_shape - self.kernel_schedule = kernel_schedule - self.epilogue_schedule = epilogue_schedule - self.tile_scheduler = tile_scheduler - self.stages = stages - - self.math_instruction = math_instruction - self.instruction_shape = math_instruction.instruction_shape - - # Number of warps along x, y, z directions - self.warp_count = warp_count - - self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule) - - def clone_and_update(self, td: dict): - attrs = { - "cluster_shape": None, - "threadblock_shape": None, - "warp_count": None, - "stages": None, - "instruction_shape": None, - "kernel_schedule": None, - "epilogue_schedule": None, - "tile_scheduler": None - } - for key in attrs.keys(): - if key in td.keys(): - attrs[key] = td[key] - else: - attrs[key] = getattr(self, key) - - attrs["math_instruction"] = MathInstruction( - attrs["instruction_shape"], - self.math_instruction.element_a, - self.math_instruction.element_b, - self.math_instruction.element_accumulator, - self.math_instruction.opcode_class, - self.math_instruction.math_operation - ) - - # Remove the instruction shape - del attrs["instruction_shape"] - - return TileDescription(**attrs) - - @property - def num_threads(self): - """ - Returns the number of threads in the threadblock - - :return: number of threads in the threadblock - :rtype: int or None (if warp count is None) - """ - if self.warp_count is not None: - threads = 32 - for cnt in self.warp_count: - threads *= cnt - return threads - return None - - def procedural_name(self): - """ - Returns a name identifying the tile description - - :return: name identifying the tile description - :rtype: int - """ - emit_stages = 0 if self.stages is None else self.stages - name = "%dx%dx%d_%dx%d_%dx%d" % ( - self.cluster_shape[0], - self.cluster_shape[1], - self.cluster_shape[2], - self.threadblock_shape[0], - self.threadblock_shape[1], - self.threadblock_shape[2], - emit_stages - ) - - return name - - def procedural_name_2x(self): - """ - Returns a name identifying the tile description - - :return: name identifying the tile description - :rtype: int - """ - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) - - def __str__(self): - """ - Returns a string with containing each of the tile description's values - - :return: contents of tile description - :rtype: str - """ - if self.kernel_schedule is not None: - kschedule = self.kernel_schedule - else: - kschedule = KernelScheduleType.ScheduleAuto - - if self.epilogue_schedule is not None: - eschedule = self.epilogue_schedule - else: - eschedule = EpilogueScheduleType.ScheduleAuto - - if self.tile_scheduler is not None: - tschedule = self.tile_scheduler.name - else: - tschedule = "None" - return f""" -{{ - ClusterShape: {self.cluster_shape} - ThreadblockShape: {self.threadblock_shape} - WarpCount: {self.warp_count} - Stages: {self.stages if self.stages is not None else 'Auto'} - InstructionShape: {self.math_instruction.instruction_shape} - Kernel schedule: {kschedule.name} - Epilogue schedule: {kschedule.name} - TileScheduler: {tschedule} -}}""" - - -class TensorDescription: - def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none): - self.element = element - self.layout = layout - if element != DataType.void: - self.alignment = min(128 // DataTypeSize[self.element], alignment) - else: - self.alignment = alignment - self.complex_transform = complex_transform - - -def CalculateSmemUsagePerStage(operation): - """ - Returns the amount of shared memory in bytes consumed in a single stage of a kernel. - - :param op: operation for which the maximum stages should be computed. If stages are - set via the `op.tile_description.stages` parameter, this setting is ignored - in the present calculation - :type op: cutlass_cppgen.backend.Operation - - :return: number of bytes of shared memory consumed by a single stage - :rtype: int - """ - m, n, k = operation.tile_description.threadblock_shape - - if operation.operation_kind == OperationKind.Gemm: - stage_barrier_bytes = 32 - return ( - (DataTypeSize[operation.A.element] * m * k // 8) - + (DataTypeSize[operation.B.element] * k * n // 8) - + stage_barrier_bytes - ) - else: - raise Exception("Unsupported operation kind {}.".format(operation.operation_kind)) - - -def CalculateSmemUsage(operation): - """ - Returns the amount of shared memory in bytes consumed by a kernel. - - :param op: operation for which the maximum stages should be computed. If stages are - set via the `op.tile_description.stages` parameter, this setting is ignored - in the present calculation - :type op: cutlass_cppgen.backend.Operation - - :return: int - """ - return operation.tile_description.stages * CalculateSmemUsagePerStage(operation) - - -class ApiVersion(enum.Enum): - """ - Differentiate between CUTLASS 2.x and 3.x API versions - """ - - v2x = enum_auto() - v3x = enum_auto() - - -def api_version(arch, opclass, dtype): - """ - Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x - or 3.x for code emission. - - :param arch: compute capability of device on which to run - :type arch: int - :param opclass: class of the operation being performed - :type opclass: cutlass_library.OpcodeClass - :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same) - :type dtype: cutlass_library.DataType - - :return: API version to be used in code emission - :rtype: ApiVersion - """ - if (arch in [90, 100, 101, 103] and - opclass == OpcodeClass.TensorOp and - (dtype != DataType.f64)): - return ApiVersion.v3x - else: - return ApiVersion.v2x - - -class EmissionType(enum.Enum): - """ - Tags for whether to emit a kernel- or device-level operation - """ - - Kernel = enum_auto() - Device = enum_auto() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py deleted file mode 100644 index 30e6bb3108ddd30e3776cf92b0671fce4fae5a93..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py +++ /dev/null @@ -1,121 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import numpy as np - -import cutlass_cppgen -from cutlass_cppgen.utils.datatypes import is_numpy_tensor -from cutlass_cppgen.utils.lazy_import import lazy_import - -if cutlass_cppgen.use_rmm: - import rmm -else: - cudart = lazy_import("cuda.cudart") - - -class PoolMemoryManager: - def __init__(self, init_pool_size: int, max_pool_size: int) -> None: - self.pool = rmm.mr.PoolMemoryResource( - rmm.mr.CudaMemoryResource(), - initial_pool_size=init_pool_size, - maximum_pool_size=max_pool_size - ) - self.mr = rmm.mr.TrackingResourceAdaptor(self.pool) - rmm.mr.set_current_device_resource(self.mr) - - def pool_size(self): - return self.pool.pool_size() - - -class DevicePtrWrapper: - """ - Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer - (at least in terms of the interface used by the CUTLASS Python interface) - """ - def __init__(self, dev_ptr): - self.dev_ptr = dev_ptr - - @property - def ptr(self): - return self.dev_ptr - - -def _todevice(host_data): - """ - Helper for transferring host data to device memory - """ - if cutlass_cppgen.use_rmm: - return rmm.DeviceBuffer.to_device(host_data.tobytes()) - else: - nbytes = len(host_data.tobytes()) - dev_ptr_wrapper = device_mem_alloc(nbytes) - err, = cudart.cudaMemcpy( - dev_ptr_wrapper.ptr, - host_data.__array_interface__['data'][0], - nbytes, - cudart.cudaMemcpyKind.cudaMemcpyHostToDevice - ) - if err != cudart.cudaError_t.cudaSuccess: - raise Exception(f"cudaMemcpy failed with error {err}") - return dev_ptr_wrapper - - -def todevice(host_data, dtype=np.float32): - """ - Pass the host_data to device memory - """ - if isinstance(host_data, list): - return _todevice(np.array(host_data, dtype=dtype)) - elif is_numpy_tensor(host_data): - return _todevice(host_data) - - -def device_mem_alloc(size): - if cutlass_cppgen.use_rmm: - return rmm.DeviceBuffer(size=size) - else: - err, ptr = cudart.cudaMalloc(size) - if err != cudart.cudaError_t.cudaSuccess: - raise Exception(f"cudaMalloc failed with error {err}") - return DevicePtrWrapper(ptr) - - -def align_size(size, alignment=256): - return ((size + alignment - 1) // alignment) * alignment - - -def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34): - if cutlass_cppgen.use_rmm: - memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size) - return memory_pool - else: - return None diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py deleted file mode 100644 index 10ee67bc6f547d079b6d990e7abea69a16549c16..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py +++ /dev/null @@ -1,140 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import ctypes -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") - -from cutlass_cppgen.backend.utils.device import device_cc - -_supports_cluster_launch = None - - -def supports_cluster_launch(): - from cuda import __version__ - _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")] - global _supports_cluster_launch - if _supports_cluster_launch is None: - major, minor = _version_splits[0], _version_splits[1] - _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8)) - return _supports_cluster_launch - - -class LaunchConfiguration: - def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0): - self.grid = grid - self.block = block - self.shared_memory_capacity = smem - - -class ExecutableOperation: - def __init__(self, operation): - self.operation = operation - self.module = None - self.kernel = None - - def name(self): - return self.operation.procedural_name() - - def emit(self): - return "" - - def can_implement(self, configuration, arguments): - raise NotImplementedError() - - def get_host_workspace_size(self, arguments): - raise NotImplementedError() - - def get_device_workspace_size(self, arguments): - raise NotImplementedError() - - def plan(self, arguments): - raise NotImplementedError() - - def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None): - raise NotImplementedError() - - def run_with_clusters(self, launch_config, kernel_params, stream=None): - if not stream: - stream = cuda.CUstream(0) - if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"): - attr = cuda.CUlaunchAttribute() - attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape - attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attrs = [attr] - - # Allow for non-portable cluster sizes - err, = cuda.cuFuncSetAttribute( - self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1) - if err != cuda.CUresult.CUDA_SUCCESS: - return err - else: - attrs = [] - - config = cuda.CUlaunchConfig() - config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid - config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block - config.blockDimZ = launch_config.block[2] - config.sharedMemBytes = launch_config.shared_memory_capacity - config.hStream = stream - config.attrs = attrs - config.numAttrs = len(attrs) - - err, = cuda.cuLaunchKernelEx( - config, f=self.kernel, kernelParams=kernel_params, extra=0) - return err - - def run_without_clusters(self, launch_config, kernel_params, stream=None): - if not stream: - stream = cuda.CUstream(0) - err, = cuda.cuLaunchKernel( - self.kernel, - launch_config.grid[0], launch_config.grid[1], launch_config.grid[2], - launch_config.block[0], launch_config.block[1], launch_config.block[2], - launch_config.shared_memory_capacity, - stream, - kernel_params, - 0) - - return err - - def run(self, host_workspace, device_workspace, launch_config, stream=None): - if not stream: - stream = cuda.CUstream(0) - cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace) - packed = (ctypes.c_void_p * 1)() - packed[0] = ctypes.addressof(cArg) - - if supports_cluster_launch(): - return self.run_with_clusters(launch_config, packed, stream) - else: - return self.run_without_clusters(launch_config, packed, stream) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py deleted file mode 100644 index 535cea2cb2a23ccbb29cce7233f42147ed2ea5eb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py +++ /dev/null @@ -1,455 +0,0 @@ -################################################################################ -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ -from __future__ import annotations - -import ctypes -from typing import Union - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") -import numpy as np - -from cutlass_library import ( - DataTypeNames, - DataTypeSize, - DataTypeTag, - LayoutType, - SubstituteTemplate -) - -import cutlass_cppgen -from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params -from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend -from cutlass_cppgen.backend.library import TensorDescription -from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper -from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass_cppgen.shape import MatrixCoord -from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor - - -class ReductionOperation: - pass - - -class ReductionArguments: - """ - Arguments of reduction - """ - - def __init__( - self, - operation: ReductionOperation, - problem_size: "list[int]", - partitions: int, - workspace: cuda.CUdeviceptr, - destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", - source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", - **kwargs, - ) -> None: - # tensor_C can be interpreted as the bias with bias=True in keyword args - if "bias" in kwargs.keys(): - self.bias = kwargs["bias"] - else: - # by default, tensor_C is not bias - self.bias = False - if "stream" in kwargs.keys(): - self.stream = kwargs["stream"] - else: - self.stream = cuda.CUstream(0) - - self.operation = operation - self.ptr_workspace = workspace - - # number of split-k partitions - self.partitions = partitions - - if is_numpy_tensor(destination): - self.host_D = destination - self.destination_buffer = NumpyFrontend.argument(destination, True) - self.source_buffer = NumpyFrontend.argument(source, False) - self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr) - self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr) - elif is_torch_tensor(destination): - self.ptr_destination = TorchFrontend.argument(destination) - self.ptr_source = TorchFrontend.argument(source) - elif isinstance(destination, cuda.CUdeviceptr): - self.ptr_destination = destination - self.ptr_source = source - else: - raise TypeError("unknown Type") - - self.problem_size = MatrixCoord_(problem_size[0], problem_size[1]) - - self.partition_stride = ( - problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8 - ) - - if "output_op" in kwargs.keys(): - self.output_op = kwargs["output_op"] - else: - self.output_op = self.operation.epilogue_type(1.0, 0.0) - - self.get_arguments() - - @staticmethod - def get_tensor_ref( - extent: "tuple[int]", - device_ptr: cuda.CUdeviceptr, - layout: LayoutType, - ): - if layout == LayoutType.RowMajor: - return TensorRef2D_(int(device_ptr), extent[1]) - else: - raise ValueError(f"Unknown layout type {layout}") - - def get_arguments(self): - ref_workspace = ReductionArguments.get_tensor_ref( - extent=[ - self.problem_size.row, - self.problem_size.column, - ], - device_ptr=self.ptr_workspace, - layout=LayoutType.RowMajor, - ) - if self.bias: - ref_source = ReductionArguments.get_tensor_ref( - extent=[0, 0], - device_ptr=self.ptr_source, - layout=LayoutType.RowMajor, - ) - else: - ref_source = ReductionArguments.get_tensor_ref( - extent=[ - self.problem_size.row, - self.problem_size.column, - ], - device_ptr=self.ptr_source, - layout=LayoutType.RowMajor, - ) - - ref_destination = ReductionArguments.get_tensor_ref( - extent=[ - self.problem_size.row, - self.problem_size.column, - ], - device_ptr=self.ptr_destination, - layout=LayoutType.RowMajor, - ) - - self.c_arguments = self.operation.argument_type( - self.problem_size, - self.partitions, - self.partition_stride, - ref_workspace, - ref_destination, - ref_source, - self.output_op, - ) - - params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments)) - self.host_workspace = bytearray(params_.contents) - - def sync(self): - (err,) = cudart.cudaDeviceSynchronize() - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - - if hasattr(self, "host_D"): - (err,) = cuda.cuMemcpyDtoH( - self.host_D, - self.ptr_destination, - self.host_D.size * self.host_D.itemsize, - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - self.free() - - def free(self): - """ - Frees allocated device-side memory - """ - # Free any device memory allocated manually - if not cutlass_cppgen.use_rmm: - for attr in ["destination_buffer", "source_buffer"]: - if hasattr(self, attr): - buf = getattr(self, attr) - if isinstance(buf, DevicePtrWrapper): - err, = cudart.cudaFree(buf.ptr) - if err != cudart.cudaError_t.cudaSuccess: - raise RuntimeError(f"cudaFree failed with error {err}") - del buf - - -class ReductionRT(ExecutableOperation): - """ - ReductionRT manages the CUTLASS runtime components for reduction - """ - - KernelTemplate = r""" -extern "C" -__global__ void -${operation_name}(${operation_name}${operation_suffix}::Params params) { - - // Dynamic shared memory base pointer - extern __shared__ int SharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - ${operation_name}${operation_suffix}::SharedStorage *shared_storage = - reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); - - ${operation_name}${operation_suffix} op; - - op(params, *shared_storage); -} - """ - HostTemplate = r""" -extern "C" { - // Get the size of params in bytes - int ${operation_name}_get_param_size(){ - return sizeof(${operation_name}${operation_suffix}::Params); - } - - // Get the size of dynamic shared memory in bytes - int ${operation_name}_shared_memory_size() { - return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); - } - - // Get the params as byte array - char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){ - char *bytes = ((char*)(params)); - char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; - for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) - output[i] = bytes[i]; - - return output; - } -} - """ - - def __init__(self, operation: ReductionOperation): - super().__init__(operation) - - self.operation: ReductionOperation = operation - self.emitter = EmitReductionInstance("_type") - - self.elements_per_access = self.operation.count - ( - self.argument_type, - self.epilogue_type, - ) = get_reduction_params(operation.epilogue_functor) - self.argtype = [ctypes.POINTER(self.argument_type)] - - def emit(self): - return self.emitter.emit(self.operation) - - def plan(self, arguments: ReductionArguments): - block_shape = [ - self.operation.shape.column // self.elements_per_access, - self.operation.shape.row, - 1, - ] - grid_shape = [ - (arguments.problem_size.row + self.operation.shape.row - 1) - // self.operation.shape.row, - (arguments.problem_size.column + self.operation.shape.column - 1) - // self.operation.shape.column, - 1, - ] - return LaunchConfiguration( - grid_shape, - block_shape, - self.shared_memory_capacity, - ) - - def initialize(self): - (err,) = cuda.cuFuncSetAttribute( - self.kernel, - attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - value=self.shared_memory_capacity, - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error: {err}") - - -class ReductionOperation: - """ - CUTLASS reduction Operation - """ - - def __init__( - self, - shape: MatrixCoord, - C: TensorDescription, - element_accumulator, - element_workspace=None, - element_compute=None, - epilogue_functor=None, - count: int = 1, - partitions_per_stage: int = 4, - ) -> None: - self.shape = shape - self.epilogue_functor = epilogue_functor - self.element_accumulator = element_accumulator - - if element_workspace is None: - self.element_workspace = element_accumulator - else: - self.element_workspace = element_workspace - - if element_compute is None: - self.element_compute = element_accumulator - else: - self.element_compute = element_compute - - self.element_output = C.element - self.C: TensorDescription = C - - # Reduce op processing size - self.count: int = count - - # Number of partitions to reduce per stage - self.partitions_per_stage: int = partitions_per_stage - - self.rt_module: ReductionRT = ReductionRT(self) - self.argument_type = self.rt_module.argument_type - self.epilogue_type = self.rt_module.epilogue_type - - def extended_name(self): - extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}" - - return SubstituteTemplate( - extend_name, - { - "element_workspace": DataTypeNames[self.element_workspace], - "element_accumulator": DataTypeNames[self.element_accumulator], - "element_compute": DataTypeNames[self.element_compute], - "element_output": DataTypeNames[self.element_output], - }, - ) - - def configuration_name(self): - """The full procedural name indicates architecture, extended name, tile size""" - - configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}" - - threadblock = "%dx%d" % ( - self.shape.row, - self.shape.column, - ) - - return SubstituteTemplate( - configuration_name, - { - "extended_name": self.extended_name(), - "threadblock": threadblock, - }, - ) - - def procedural_name(self): - """The full procedural name indicates architeture, extended name, tile size""" - return self.configuration_name() - - def run(self, arguments: ReductionArguments) -> cuda.CUresult: - """ - Configure and launch the cuda kernel with input arguments - """ - launch_config = self.rt_module.plan(arguments) - - host_workspace = arguments.host_workspace - device_workspace = None - - err = self.rt_module.run( - host_workspace, - device_workspace, - launch_config, - arguments.stream - ) - - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - - return err - - -class EmitReductionInstance: - def __init__(self, operation_suffix="") -> None: - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/numeric_types.h", - "cutlass/arch/arch.h", - "cutlass/arch/mma.h", - "cutlass/layout/matrix.h", - "cutlass/gemm/device/gemm.h", - "cutlass/gemm/device/gemm_universal_adapter.h", - "cutlass/gemm/kernel/default_gemm_universal.h", - "cutlass/reduction/kernel/reduce_split_k.h", - "cutlass/reduction/thread/reduction_operators.h", - ] - self.template = """ -// Reduction kernel instance -using ${operation_name}_base = -typename cutlass::reduction::kernel::ReduceSplitK< - cutlass::MatrixShape<${shape_row}, ${shape_column}>, - ${epilogue_functor}, - cutlass::reduction::thread::ReduceAdd< - ${element_accumulator}, - ${element_output}, - ${count}>, - ${partition_per_stage}>; - -struct ${operation_name}${operation_suffix}: - public ${operation_name}_base { }; - """ - - def emit(self, operation: ReductionOperation): - vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128) - epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element] - - values = { - "operation_name": operation.configuration_name(), - "operation_suffix": self.operation_suffix, - "shape_row": str(operation.shape.row), - "shape_column": str(operation.shape.column), - "epilogue_functor": operation.epilogue_functor.emit(), - "element_output": DataTypeTag[operation.element_output], - "epilogue_vector_length": str(epilogue_vector_length), - "element_accumulator": DataTypeTag[operation.element_accumulator], - "element_compute": DataTypeTag[operation.element_compute], - "element_workspace": DataTypeTag[operation.element_workspace], - "count": str(operation.count), - "partition_per_stage": str(operation.partitions_per_stage), - } - - return SubstituteTemplate(self.template, values) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py deleted file mode 100644 index fffa03360f7e0eb2f3a2a20e5c8a4e04d009bee9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py +++ /dev/null @@ -1,35 +0,0 @@ -################################################################################ -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]" - -Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py deleted file mode 100644 index 0bae3bac1163c55a698dfc8722c62ac85cb25abf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -################################################################################ -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py deleted file mode 100644 index 9ed4096a6f4b772a58702c2f4b089cc32d707614..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py +++ /dev/null @@ -1,126 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utility functions for interacting with the device -""" -from __future__ import annotations - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") - -import cutlass_cppgen -from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor - - -def check_cuda_errors(result: list): - """ - Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise, - returns the result contained in the remaining fields of `result`. - - :param result: the results of the `cudart` method, consisting of an error code and any method results - :type result: list - - :return: non-error-code results from the `results` parameter - """ - # `result` is of the format : (cudaError_t, result...) - err = result[0] - if err.value: - raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err))) - - if len(result) == 1: - return None - elif len(result) == 2: - return result[1] - else: - return result[1:] - - -def device_cc(device: int = -1) -> int: - """ - Returns the compute capability of the device with ID `device`. - - :param device: ID of the device to query - :type device: int - - :return: compute capability of the queried device (e.g., 80 for SM80) - :rtype: int - """ - if device == -1: - device = cutlass_cppgen.device_id() - - deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) - major = str(deviceProp.major) - minor = str(deviceProp.minor) - return int(major + minor) - - -def device_sm_count(device: int = -1): - if device == -1: - device = cutlass_cppgen.device_id() - err, device_sm_count = cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise Exception( - "Failed to retireve SM count. " - f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}" - ) - - return device_sm_count - - -def to_device_ptr(tensor) -> cuda.CUdeviceptr: - """ - Converts a tensor to a CUdeviceptr - - :param tensor: tensor to convert - :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int - - :return: device pointer - :rtype: cuda.CUdeviceptr - """ - if is_numpy_tensor(tensor): - ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0]) - elif is_torch_tensor(tensor): - ptr = cuda.CUdeviceptr(tensor.data_ptr()) - elif is_cupy_tensor(tensor): - ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) - elif isinstance(tensor, cuda.CUdeviceptr): - ptr = tensor - elif isinstance(tensor, int): - ptr = cuda.CUdeviceptr(tensor) - else: - raise NotImplementedError(tensor) - - return ptr diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py deleted file mode 100644 index 8e4121b59e57e26e8a32022916089e0916db4988..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.emit.pytorch import pytorch diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py deleted file mode 100644 index 58f94e15148f934c92318b586d63b669757ed5f0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py +++ /dev/null @@ -1,267 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Common utilities for emitting CUTLASS kernels -""" - -import cutlass_cppgen - -# Strings used for printing information about the generation of emitted scripts -_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)" - - -_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR} -""" - - -_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR} -""" - -_CUTLASS_KERNEL_ARGS_2x = """ - typename DeviceKernel::Arguments arguments { - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, // problem size - 1, - {alpha, beta}, - A, B, C, D, - 0, 0, 0, 0, // batch strides - DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda - DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb - DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc - DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd - }; -""" - -_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """ - typename DeviceKernel::Arguments arguments { - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, // problem size - 1, - {alpha, beta}, - A, B, C, D, - 0, 0, 0, 0, // batch strides - DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda - DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb - DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc - DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd - -1 // avail_sms - }; -""" - -_CUTLASS_KERNEL_RUN_GEMM_2x = """ -using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; - -cutlass::Status ${name}_kernel_run(int M, int N, int K, - const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, - ElementCompute alpha, ElementCompute beta) { - ${args} - size_t workspace_size = DeviceKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - DeviceKernel gemm_op; - cutlass::Status status = gemm_op.initialize(arguments, - workspace.get(), - nullptr); // CUDA stream - - if (status != cutlass::Status::kSuccess) { - return status; - } - - status = gemm_op(); - return status; -} -""" - -_CUTLASS_KERNEL_RUN_GEMM_3x = """ -using StrideA = typename DeviceKernel::GemmKernel::StrideA; -using StrideB = typename DeviceKernel::GemmKernel::StrideB; -using StrideC = typename DeviceKernel::GemmKernel::StrideC; -using StrideD = typename DeviceKernel::GemmKernel::StrideD; - -using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; - -cutlass::Status ${name}_kernel_run( - int M, int N, int K, int L, - const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, - ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) { - - typename DeviceKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K, L}, // problem size - { - A, // ptrA - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A - B, // ptrB - cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B - }, - { - {alpha, beta}, - C, // ptrC - cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C - D, // ptrD - cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D - }, - hw_info - }; - - size_t workspace_size = DeviceKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - DeviceKernel gemm_op; - cutlass::Status status = gemm_op.run(arguments, - workspace.get(), - nullptr); // CUDA stream - - return status; -} -""" - - -_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """ -using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; - -int threadblock_count = DeviceKernel::sufficient(); - -cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes, - DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D, - int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd, - ElementCompute alpha, ElementCompute beta) { - - typename DeviceKernel::Arguments arguments { - problem_sizes, - problem_count, - threadblock_count, - {alpha, beta}, - A, B, C, D, - lda, ldb, ldc, ldd - }; - - size_t workspace_size = DeviceKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - DeviceKernel gemm_op; - cutlass::Status status = gemm_op.initialize(arguments, - workspace.get(), - nullptr); // CUDA stream - - if (status != cutlass::Status::kSuccess) { - return status; - } - - status = gemm_op(); - return status; -} -""" - - -_CUTLASS_KERNEL_RUN_CONV2D_2x = """ - -using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel; -namespace { -using TensorRefA = typename UnderlyingKernel::TensorRefA; -using TensorRefB = typename UnderlyingKernel::TensorRefB; -using TensorRefC = typename UnderlyingKernel::TensorRefC; -using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute; -} - -template -TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){ - cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord); - TensorRef tensor_ref(ptr, layout); - return tensor_ref; -} - -cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size, - UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B, - UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D, - ElementCompute alpha, ElementCompute beta, std::string split_k_mode, - cudaStream_t stream, int device_id=0) { - // create the tensor references - cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent( - cutlass::conv::Operator::k${conv_kind_name}, *problem_size - ); - cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent( - cutlass::conv::Operator::k${conv_kind_name}, *problem_size - ); - cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent( - cutlass::conv::Operator::k${conv_kind_name}, *problem_size - ); - - TensorRefA tensor_ref_A = get_tensor_ref(tensor_coord_A, A); - TensorRefB tensor_ref_B = get_tensor_ref(tensor_coord_B, B); - TensorRefC tensor_ref_C = get_tensor_ref(tensor_coord_C, C); - TensorRefC tensor_ref_D = get_tensor_ref(tensor_coord_C, D); - - cutlass::conv::SplitKMode mode; - if (split_k_mode == "serial") { - mode = cutlass::conv::SplitKMode::kSerial; - } else if (split_k_mode == "parallel") { - mode = cutlass::conv::SplitKMode::kParallel; - } else { - throw std::runtime_error("Invalid split_k_mode: " + split_k_mode); - } - - typename DeviceKernel::Arguments arguments{ - *problem_size, - tensor_ref_A, - tensor_ref_B, - tensor_ref_C, - tensor_ref_D, - {alpha, beta}, - mode - }; - - DeviceKernel implicit_gemm_op; - - size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); - - void* workspace_ptr = device_memory_allocation(workspace_size, device_id); - - cutlass::Status status = implicit_gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - return status; - } - - status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream); - if (status != cutlass::Status::kSuccess) { - return status; - } - - // - // Launch initialized CUTLASS kernel - // - status = implicit_gemm_op(stream); - - return status; -} -""" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py deleted file mode 100644 index fe96f3ede11163da01520f972eb97282a2ab2b14..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py +++ /dev/null @@ -1,936 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel. -If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method. - -Example usage with JIT compilation: - -.. highlight:: python -.. code-block:: python - - plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor) - op = plan.construct() - mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) - - # Generate inputs for the GEMM - A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] - - # Run the module - D = mod.run(A, B, C) - - -Example usage without JIT compilation: - -.. highlight:: python -.. code-block:: python - - plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) - op = plan.construct() - cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output') - -After this call, the directory ``output`` contains ``setup.py``, -``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from -within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``. - -The module can later be used in Python via: - -.. highlight:: python -.. code-block:: python - - import torch - import cutlass_gemm - - # Generate inputs for the GEMM - A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] - - # Run the module - D = cutlass_gemm.run(A, B, C) -""" - -import logging -import os - -from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate - -from cutlass_cppgen import CUTLASS_PATH, logger, swizzle -from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal -from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation -from cutlass_cppgen.backend.library import ApiVersion -from cutlass_cppgen.emit import common -from cutlass_cppgen.utils.datatypes import is_torch_available - -if is_torch_available(): - import torch - - -_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ -#include -#include -#include -#include -#include "cutlass/cutlass.h" -#include "cutlass/util/device_memory.h" - -// helper function allocating the memory -void* device_memory_allocation(size_t size, int device_id=0) { - if (size > 0) { - torch::Device device(torch::kCUDA, device_id); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device); - at::Tensor device_tensor = torch::empty({(long)size,}, options); - return reinterpret_cast(device_tensor.data_ptr()); - } else { - return nullptr; - } -} - -${includes} -${declaration} -${impl} -""" - -_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ -#include -#include -#include - -// CUDA forward declarations -at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f); - -// C++ interface -at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f) { - return ${name}_kernel(A, B, C, alpha, beta); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", py::overload_cast, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); -} -""" - -_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ -#include -#include -#include - -// CUDA forward declarations -std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f); - -// C++ interface -std::vector ${name}(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f) { - return ${name}_kernel(A, B, C, alpha, beta); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", py::overload_cast&, const std::vector&, at::optional>, float, float>(&${name}), - py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); -} -""" - -_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ -#include -#include -#include - -// CUDA forward declarations -at::Tensor ${name}_kernel( - const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, - float alpha=1.f, float beta=0.f, - std::string split_k_mode="serial", int split_k_slices=1); - -// C++ interface -at::Tensor ${name}( - const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, - float alpha=1.f, float beta=0.f, - std::string split_k_mode="serial", int split_k_slices=1) { - return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", - py::overload_cast< - const at::Tensor&, const at::Tensor&, at::optional, - std::tuple, std::tuple, std::tuple, float, float, std::string, int>( - &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, - py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), - py::arg("alpha") = 1.f, py::arg("beta") = 0.f, - py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); -} -""" - -_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ -#include -#include -#include - -// CUDA forward declarations -at::Tensor ${name}_kernel( - std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, - float alpha=1.f, float beta=0.f, - std::string split_k_mode="serial", int split_k_slices=1); - -// C++ interface -at::Tensor ${name}( - std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, - float alpha=1.f, float beta=0.f, - std::string split_k_mode="serial", int split_k_slices=1) { - return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", - py::overload_cast< - std::tuple, const at::Tensor&, const at::Tensor&, at::optional, - std::tuple, std::tuple, std::tuple, float, float, std::string, int>( - &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, - py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), - py::arg("alpha") = 1.f, py::arg("beta") = 0.f, - py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); -} -""" - -_PYTORCH_GEMM_INCLUDES = { - ApiVersion.v2x: """ -#include "cutlass/gemm/device/gemm_universal.h" -""", - ApiVersion.v3x: """ -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/util/packed_stride.hpp" -""", -} - -_PYTORCH_GROUPED_GEMM_INCLUDES = """ -#include "cutlass/gemm/kernel/default_gemm_grouped.h" -#include "cutlass/gemm/device/gemm_grouped.h" -""" - -_PYTORCH_CONV2D_INCLUDES = """ -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -#include "cutlass/conv/device/implicit_gemm_convolution.h" -""" - -_CUTLASS_TYPE_TO_TORCH_TYPE = { - DataType.f16: "torch::kF16", - DataType.f32: "torch::kF32", - DataType.f64: "torch::kF64", - DataType.s8: "torch::kI8", - DataType.s32: "torch::kI32", - DataType.bf16: "torch::kBFloat16", -} - -_PYTORCH_GEMM_IMPL_TEMPLATE_2x = ( - common._CUTLASS_KERNEL_RUN_GEMM_2x - + """ -at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { - int M = A.size(0); - int N = B.size(1); - int K = A.size(1); - - typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? - nullptr : - reinterpret_cast(C->contiguous().data_ptr()); - at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); - - cutlass::Status status = ${name}_kernel_run(M, N, K, - reinterpret_cast(A.contiguous().data_ptr()), - reinterpret_cast(B.contiguous().data_ptr()), - ptrC, - reinterpret_cast(D.contiguous().data_ptr()), - ElementCompute(alpha), ElementCompute(beta)); - - TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); - return D; -} -""" -) - -_PYTORCH_GEMM_IMPL_TEMPLATE_3x = ( - common._CUTLASS_KERNEL_RUN_GEMM_3x - + """ -bool hw_info_queried = false; -cutlass::KernelHardwareInfo hw_info; - -at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { - int M = A.size(0); - int N = B.size(1); - int K = A.size(1); - int L = 1; - - // Query hardware info if we haven't already - if (!hw_info_queried) { - hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - } - - typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? - nullptr : - reinterpret_cast(C->contiguous().data_ptr()); - at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); - - cutlass::Status status = ${name}_kernel_run(M, N, K, L, - reinterpret_cast(A.contiguous().data_ptr()), - reinterpret_cast(B.contiguous().data_ptr()), - ptrC, - reinterpret_cast(D.contiguous().data_ptr()), - ElementCompute(alpha), ElementCompute(beta), - hw_info); - - TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); - return D; -} -""" -) - - -_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = ( - common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x - + """ -std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C, float alpha, float beta) { - size_t num = A.size(); - - // To avoid performing many small cudaMallocs and host-to-device copies, - // we serialize the grouped GEMM arguments on the host, allocate one - // large chunk of device memory, and perform a single cudaMemcpy to - // copy the host data to the device. Allocation overheads could be - // avoided by using a memory pool. - - // Calculate the total size of the data to be copied from host to device - size_t total_size = sizeof(cutlass::gemm::GemmCoord) + - sizeof(DeviceKernel::ElementA*) + - sizeof(DeviceKernel::ElementB*) + - sizeof(DeviceKernel::ElementC*) + - sizeof(DeviceKernel::ElementC*) + - sizeof(int64_t) + - sizeof(int64_t) + - sizeof(int64_t); - total_size *= num; - - // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple - // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system). - // To ensure that we don't end up having misaligned loads in the kernel, - // we pad to the nearest multiple of 8. - // - // Note that, even on a 32-bit system (for which sizeof(X*) will not equal - // sizeof(int64_t)), only padding between the list of GemmCoords and the - // list of ptr_As is sufficient because the set of four equal-length lists of pointers - // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always - // start on a multiple of 8. - int64_t padding = 8 - (total_size % 8); - total_size += padding; - - uint8_t* host_data = new uint8_t[total_size]; - cutlass::DeviceAllocation device_data(total_size); - - uint8_t* start = host_data; - cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast(start); - - // Apply the padding after the list of GemmCoords - start += num * sizeof(cutlass::gemm::GemmCoord) + padding; - - int64_t ptr_A_offset = start - host_data; - DeviceKernel::ElementA** ptr_A_host = reinterpret_cast(start); - start += num * sizeof(DeviceKernel::ElementA*); - - int64_t ptr_B_offset = start - host_data; - DeviceKernel::ElementB** ptr_B_host = reinterpret_cast(start); - start += num * sizeof(DeviceKernel::ElementB*); - - int64_t ptr_C_offset = start - host_data; - DeviceKernel::ElementC** ptr_C_host = reinterpret_cast(start); - start += num * sizeof(DeviceKernel::ElementC*); - - int64_t ptr_D_offset = start - host_data; - DeviceKernel::ElementC** ptr_D_host = reinterpret_cast(start); - start += num * sizeof(DeviceKernel::ElementC*); - - int64_t lda_offset = start - host_data; - int64_t* lda_host = reinterpret_cast(start); - start += num * sizeof(int64_t); - - int64_t ldb_offset = start - host_data; - int64_t* ldb_host = reinterpret_cast(start); - start += num * sizeof(int64_t); - - int64_t ldc_offset = start - host_data; - int64_t* ldc_host = reinterpret_cast(start); - start += num * sizeof(int64_t); - - std::vector D(num); - - bool need_C = (C != at::nullopt) && (beta != 0.f); - for (size_t i = 0; i < num; ++i) { - int M = A[i].size(0); - int N = B[i].size(1); - int K = A[i].size(1); - *(problem_sizes_host + i) = {M, N, K}; - *(ptr_A_host + i) = reinterpret_cast(A[i].contiguous().data_ptr()); - *(ptr_B_host + i) = reinterpret_cast(B[i].contiguous().data_ptr()); - - if (need_C) { - *(ptr_C_host + i) = reinterpret_cast(C->at(i).contiguous().data_ptr()); - } - else { - *(ptr_C_host + i) = nullptr; - } - - D[i] = B[i].new_empty({M, N}, ${torch_type_C}); - *(ptr_D_host + i) = reinterpret_cast(D[i].contiguous().data_ptr()); - - *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0); - *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0); - *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0); - } - - device_data.copy_from_host(host_data); - - cutlass::Status status = ${name}_kernel_run( - num, - reinterpret_cast(device_data.get()), - reinterpret_cast(device_data.get() + ptr_A_offset), - reinterpret_cast(device_data.get() + ptr_B_offset), - reinterpret_cast(device_data.get() + ptr_C_offset), - reinterpret_cast(device_data.get() + ptr_D_offset), - reinterpret_cast(device_data.get() + lda_offset), - reinterpret_cast(device_data.get() + ldb_offset), - reinterpret_cast(device_data.get() + ldc_offset), - reinterpret_cast(device_data.get() + ldc_offset), - ElementCompute(alpha), ElementCompute(beta)); - - delete[] host_data; - - TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); - return D; -} -""" -) - -_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """ - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - cutlass::Status status = ${name}_kernel_run( - &problem_size, - reinterpret_cast(A.data_ptr()), - reinterpret_cast(B.data_ptr()), - ptrC, - reinterpret_cast(D.data_ptr()), - alpha, beta, - split_k_mode, stream, B.device().index()); - - TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); - return D; -} -""" - -_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = ( - common._CUTLASS_KERNEL_RUN_CONV2D_2x - + """ -at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, - float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) { - int N, H, W, C_, K, R, S, P, Q; - N = A.size(0); - C_ = A.size(1); - H = A.size(2); - W = A.size(3); - - K = B.size(0); - R = B.size(2); - S = B.size(3); - - cutlass::conv::Conv2dProblemSize problem_size( - cutlass::Tensor4DCoord(N, H, W, C_), - cutlass::Tensor4DCoord(K, R, S, C_), - cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), - cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), - cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), - cutlass::conv::Mode::kCrossCorrelation, - split_k_slices - ); - - P = problem_size.P; - Q = problem_size.Q; - - typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? - nullptr : - reinterpret_cast(C->data_ptr()); - - torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); - at::Tensor D = torch::zeros({N, K, P, Q}, options); -""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x -) - - -_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = ( - common._CUTLASS_KERNEL_RUN_CONV2D_2x - + """ -at::Tensor ${name}_kernel(std::tuple input_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, - std::string split_k_mode="serial", int split_k_slices=1) { - int N, H, W, C_, K, R, S; - N = std::get<0>(input_size); - C_ = std::get<1>(input_size); - H = std::get<2>(input_size); - W = std::get<3>(input_size); - - K = B.size(0); - R = B.size(2); - S = B.size(3); - - cutlass::conv::Conv2dProblemSize problem_size( - cutlass::Tensor4DCoord(N, H, W, C_), - cutlass::Tensor4DCoord(K, R, S, C_), - cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), - cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), - cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), - cutlass::conv::Mode::kCrossCorrelation, - split_k_slices - ); - - typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? - nullptr : - reinterpret_cast(C->data_ptr()); - - torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); - at::Tensor D = torch::empty({N, C_, H, W}, options); -""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x -) - - -_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = ( - common._CUTLASS_KERNEL_RUN_CONV2D_2x - + """ -at::Tensor ${name}_kernel(std::tuple weight_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, - std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, - std::string split_k_mode="serial", int split_k_slices=1) { - int N, H, W, C_, K, R, S; - K = std::get<0>(weight_size); - C_ = std::get<1>(weight_size); - R = std::get<2>(weight_size); - S = std::get<3>(weight_size); - - N = B.size(0); - H = B.size(2); - W = B.size(3); - - cutlass::conv::Conv2dProblemSize problem_size( - cutlass::Tensor4DCoord(N, H, W, C_), - cutlass::Tensor4DCoord(K, R, S, C_), - cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), - cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), - cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), - cutlass::conv::Mode::kCrossCorrelation, - split_k_slices - ); - - typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? - nullptr : - reinterpret_cast(C->data_ptr()); - - torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); - at::Tensor D = torch::empty({K, C_, R, S}, options); -""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x -) - - -_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name='${name}', - ext_modules=[ - CUDAExtension('${name}', [ - '${name}.cpp', - '${name}_kernel.cu', - ], - include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'], - extra_compile_args={ - 'cxx': ['-std=c++17'], - 'nvcc': ['-std=c++17', ${extra_compile_args}], - }, - libraries=['cuda'] - ), - ], - cmdclass={ - 'build_ext': BuildExtension - }) - -""" - - -def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""): - """ - Generates a setup.py file for the extension - - :param name: name of the module to generate - :type name: str - :param sourcedir: directory to which generated source files should be written - :type sourcedir: str - :param extra_compile_args: additional arguments to pass to setup.py - :type extra_args: str - """ - setup_py_file = os.path.join(sourcedir, "setup.py") - setup_source = SubstituteTemplate( - _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args} - ) - with open(setup_py_file, "w") as outfile: - outfile.write(setup_source) - - -class _ArchListSetter: - """ - Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST`` - environment variable when building a PyTorch CUDA module. - - ``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch - CUDA module should be compiled. - - For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of - ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the - compilation of the module. - - This utility wraps the building of a PyTorch CUDA module with a setting of this environment - variable according to the current compute capability being targetted. - - Example usage: - - .. highlight:: python - .. code-block:: python - - # Temporarily set TORCH_CUDA_ARCH_LIST="8.0" - with _ArchListSetter(80): - # Perform JIT compilation and loading of the module - mod = torch.utils.cpp_extension.load(...) - - :param cc: compute capability - :type cc: int - """ - - _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST" - - def __init__(self, cc: int): - self.cc_str = ".".join(list(str(cc))) - - def __enter__(self): - """ - Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc`` - """ - self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST) - os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str - - return self - - def __exit__(self, exc_type, exc_val, traceback): - """ - Restores the old value of TORCH_CUDA_ARCH_LIST - """ - if self.old_arch_list is None: - del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] - else: - os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list - - -def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): - """ - JIT compiles and loads a PyTorch CUDA extension. - - :param name: name of the module to generate - :type name: str - :param cc: compute capability of the device the module should target - :type cc: int - :param cpp_file: path to file containing extension's C++ interface - :type cpp_file: str - :param cuda_file: path to file containing extension's CUDA interface - :type cuda_file: str - - :return: loaded PyTorch module - """ - - from torch.utils.cpp_extension import load - - extra_cuda_cflags = ["-std=c++17"] - if cc in [90, 100, 101, 103]: - # PyTorch does not currently add the sm_90a target when compute capability - # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target. - extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a") - - with _ArchListSetter(cc): - jitmodule = load( - name, - [cpp_file, cuda_file], - extra_cuda_cflags=extra_cuda_cflags, - extra_include_paths=[ - os.path.join(CUTLASS_PATH, "include"), - os.path.join(CUTLASS_PATH, "tools/util/include"), - ], - extra_ldflags=["-lcuda"], - verbose=(logger.level == logging.DEBUG) - ) - return jitmodule - - -def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): - """ - Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM - specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time - compiled, loaded, and returned. - - :param op: operation to emit in the module - :param name: name of the module to generate - :type name: str - :param cc: compute capability of the device the module should target - :type cc: int - :param jit: whether the module should be just-in-time compiled - :type jit: bool - :param sourcedir: directory to which generated source files should be written - :type sourcedir: str - - :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise - """ - if sourcedir != "" and not os.path.isdir(sourcedir): - os.makedirs(sourcedir) - - cuda_file = os.path.join(sourcedir, name + "_kernel.cu") - extra_kw = {} - if op.api == ApiVersion.v3x: - impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x - else: - impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x - if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK: - extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K - else: - extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x - impl_template = ( - _PYTORCH_GEMM_IMPL_TEMPLATE_3x - if op.api == ApiVersion.v3x - else _PYTORCH_GEMM_IMPL_TEMPLATE_2x - ) - cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) - cuda_source = SubstituteTemplate( - _PYTORCH_CUDA_TEMPLATE, - { - "includes": _PYTORCH_GEMM_INCLUDES[op.api], - "declaration": op.rt_module.emit(), - "procedural_name": op.procedural_name(), - "impl": cuda_impl, - "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], - }, - ) - with open(cuda_file, "w") as outfile: - outfile.write(cuda_source) - - cpp_file = os.path.join(sourcedir, name + ".cpp") - cpp_source = SubstituteTemplate( - _PYTORCH_GEMM_CPP_TEMPLATE, - {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"}, - ) - with open(cpp_file, "w") as outfile: - outfile.write(cpp_source) - - extra_compile_args = "" - if cc in [90, 100, 101, 103]: - extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'" - _generate_setup(name, sourcedir, extra_compile_args) - - if jit: - return _jit(name, cc, cpp_file, cuda_file) - - return None - - -def _pytorch_grouped_gemm( - op, name: str, cc: int, jit: bool = False, sourcedir: str = "" -): - """ - Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM - specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time - compiled, loaded, and returned. - - :param op: operation to emit in the module - :param name: name of the module to generate - :type name: str - :param cc: compute capability of the device the module should target - :type cc: int - :param jit: whether the module should be just-in-time compiled - :type jit: bool - :param sourcedir: directory to which generated source files should be written - :type sourcedir: str - - :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise - """ - if op.api != ApiVersion.v2x: - raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x") - - if sourcedir != "" and not os.path.isdir(sourcedir): - os.makedirs(sourcedir) - - cuda_file = os.path.join(sourcedir, name + "_kernel.cu") - cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name}) - cuda_source = SubstituteTemplate( - _PYTORCH_CUDA_TEMPLATE, - { - "includes": _PYTORCH_GROUPED_GEMM_INCLUDES, - "declaration": op.rt_module.emit(), - "procedural_name": op.procedural_name(), - "impl": cuda_impl, - "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], - }, - ) - with open(cuda_file, "w") as outfile: - outfile.write(cuda_source) - - cpp_file = os.path.join(sourcedir, name + ".cpp") - cpp_source = SubstituteTemplate( - _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, - {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"}, - ) - with open(cpp_file, "w") as outfile: - outfile.write(cpp_source) - - _generate_setup(name, sourcedir) - - if jit: - return _jit(name, cc, cpp_file, cuda_file) - - return None - - -def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): - """ - Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d - specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time - compiled, loaded, and returned. - - :param op: operation to emit in the module - :param name: name of the module to generate - :type name: str - :param cc: compute capability of the device the module should target - :type cc: int - :param jit: whether the module should be just-in-time compiled - :type jit: bool - :param sourcedir: directory to which generated source files should be written - :type sourcedir: str - - Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or - weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions - for H/W/R/S given the same P/Q. - - :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise - """ - if sourcedir != "" and not os.path.isdir(sourcedir): - os.makedirs(sourcedir) - cuda_file = os.path.join(sourcedir, name + "_kernel.cu") - extra_kw = {} - if op.conv_kind == ConvKind.Fprop: - impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x - cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE - elif op.conv_kind == ConvKind.Dgrad: - impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x - cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE - elif op.conv_kind == ConvKind.Wgrad: - impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x - cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE - extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize() - extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element] - cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) - cuda_source = SubstituteTemplate( - _PYTORCH_CUDA_TEMPLATE, - { - "includes": _PYTORCH_CONV2D_INCLUDES, - "declaration": op.rt_module.emit(), - "procedural_name": op.procedural_name(), - "impl": cuda_impl, - "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], - }, - ) - with open(cuda_file, "w") as outfile: - outfile.write(cuda_source) - - cpp_file = os.path.join(sourcedir, name + ".cpp") - cpp_source = SubstituteTemplate( - cpp_template, - {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"}, - ) - with open(cpp_file, "w") as outfile: - outfile.write(cpp_source) - - _generate_setup(name, sourcedir) - - if jit: - return _jit(name, cc, cpp_file, cuda_file) - - return None - - -def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): - """ - Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel - specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time - compiled, loaded, and returned. - - The result of this method is files within ``sourcedir`` that can be used for building - a PyTorch module. - - :param op: operation to emit in the module - :param name: name of the module to generate - :type name: str - :param cc: compute capability of the device the module should target - :type cc: int - :param jit: whether the module should be just-in-time compiled - :type jit: bool - :param sourcedir: directory to which generated source files should be written - :type sourcedir: str - - :return: loaded PyTorch module (if ``jit=True``) or None - """ - device_op = op.device_op() - if isinstance(op, GemmOperationUniversal): - return _pytorch_gemm(device_op, name, cc, jit, sourcedir) - elif isinstance(op, GemmOperationGrouped): - return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir) - elif isinstance(op, Conv2dOperation): - return _pytorch_conv2d(device_op, name, cc, jit, sourcedir) - else: - raise Exception( - f"Operation type {type(op)} is not currently supported for PyTorch emission." - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py deleted file mode 100644 index faf6896e99ba78130ede8e09be9b9115e9169541..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.epilogue.epilogue import ( - get_activations, - get_activation_epilogue, - gelu, - hardswish, - identity, - leaky_relu, - relu, - sigmoid, - silu, - tanh, - trace -) - -from cutlass_cppgen.epilogue.evt_ops import ( - max, - multiply_add, - sum, - permute, - reshape, - maximum, - minimum, - exp -) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py deleted file mode 100644 index a3a17506ee2be609ed8d5b299114df52c55ca0cf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py +++ /dev/null @@ -1,176 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Registry of elementwise epilogues - -Elementwise epilogues can be added to many CUTLASS kernels in the CUTLAS Python interface via -code like the following for GEMM: - -.. highlight:: python -.. code-block:: python - - plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) - plan.activation = cutlass_cppgen.epilogue.relu -""" - -from cutlass_cppgen.backend import epilogue, device_cc - - -gelu = epilogue.gelu -hardswish = epilogue.hardswish -identity = epilogue.identity -leaky_relu = epilogue.leaky_relu -relu = epilogue.relu -sigmoid = epilogue.sigmoid -silu = epilogue.silu -tanh = epilogue.tanh - - -_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh] - - -def get_activations() -> list: - """ - Returns a list of available activation functions - - :return: list of available activation functions - :rtype: list - """ - return _activations - - -def get_activation_epilogue( - activation, - element_output, - elements_per_access, - element_accumulator, - element_compute, -): - """ - Return an epilogue corresponding to the activation function, data types, and alignment - used in the kernel - - :param activation: elementwise activation function to use - :param element_output: data type of the output - :param elements_per_access: alignment of operand C of the kernel - :type elements_per_access: int - :param element_accumulator: data type of the accumulated output C - :param element_compute: data type in which compute operations should be performed - - :return: epilogue functor - """ - if activation not in _activations: - raise Exception( - f"Unsupported activation type {activation}. Available activations are: {_activations}" - ) - - if activation == identity: - return epilogue.LinearCombination( - element_output, elements_per_access, element_accumulator, element_compute - ) - else: - return epilogue.LinearCombinationGeneric( - activation, - element_output, - elements_per_access, - element_accumulator, - element_compute, - ) - - -""" -Frontend for EVT that generates epilogue functor through tracing the input function -""" -from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend - - -def trace(fn, example_tensors, **kwargs): - """ - Trace `fn(**example_tensors)` and generates epilogue visitor - - :param fn or str: Python callable or string of the epilogue function - :param example_tensors: example inputs for fn - :type example_tensors: dict - - .. hightlight:: python - .. code-block:: python - import cutlass_cppgen.backend.evt - - # Define epilogue function as Python callable - def example_fn(accum, C, alpha, beta, gamma): - D = ((accum + C) * alpha - gamma) / beta - return D - - # Define the example tensors - example_inputs = { - "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"), - "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"), - "alpha": 1.5, - "beta": 0.5, - "gamma": 2.5, - "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda") - } - - # Generate the epilogue functor - epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs) - """ - if callable(fn): - class EpilogueFunctor(PythonASTFrontend): - def __init__(self, cc=None, **kwargs): - if not cc: - cc = device_cc() - super().__init__(cc, **kwargs) - pass - setattr(EpilogueFunctor, "__call__", staticmethod(fn)) - - epilogue_functor = EpilogueFunctor(**kwargs) - epilogue_functor.trace(example_tensors) - return epilogue_functor - elif isinstance(fn, str): - class EpilogueFunctor(PythonASTFrontend): - def __init__(self, cc=None, **kwargs): - self.source = textwrap.dedent(fn) - if not cc: - cc = device_cc() - super().__init__(cc, **kwargs) - - def parse(self, example_inputs) -> None: - self.example_inputs = example_inputs - self.ast = ast.parse(self.source) - self.visit(self.ast) - - epilogue_functor = EpilogueFunctor(**kwargs) - epilogue_functor.trace(example_tensors) - return epilogue_functor - else: - raise NotImplementedError("Expect a callable Python function") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py deleted file mode 100644 index 7d8e2c01286886ffc936052c84205a60a5d869fb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Collection of builtin functions used for host reference in EVT -""" - -import numpy as np - -from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor - -if is_torch_available(): - import torch - - -def multiply_add(x, y, z): - return x * y + z - - -def sum(x, dim): - if is_numpy_tensor(x): - return x.sum(axis=tuple(dim)) - elif is_torch_tensor(x): - return torch.sum(x, dim) - - -def max(x, dim): - if is_numpy_tensor(x): - return x.max(axis=tuple(dim)) - elif is_torch_tensor(x): - return torch.amax(x, dim) - - -def maximum(x, y): - if is_numpy_tensor(x): - return np.maximum(x, y) - elif is_torch_tensor(x): - return torch.maximum(x, torch.tensor(y)) - - -def minimum(x, y): - if is_numpy_tensor(x): - return np.minimum(x, y) - elif is_torch_tensor(x): - return torch.minimum(x, torch.tensor(y)) - -def exp(x): - if is_numpy_tensor(x): - return np.exp(x) - elif is_torch_tensor(x): - return torch.exp(x) - - -############################################################################## -# Layout manipulate nodes -############################################################################## - -def permute(x, indices: tuple): - if is_numpy_tensor(x): - return np.transpose(x, axes=indices) - elif is_torch_tensor(x): - return x.permute(*indices) - - -def reshape(x, new_shape: tuple): - if is_numpy_tensor(x): - return np.reshape(x, newshape=new_shape) - elif is_torch_tensor(x): - return x.view(new_shape) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py deleted file mode 100644 index f5ea04419955f6a71225b6daaeab884dcc4e3399..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py +++ /dev/null @@ -1,569 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Classes containing valid operations for a given compute capability and data types. -""" - -from itertools import combinations_with_replacement -import logging - -import cutlass_library -from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode - -import cutlass_cppgen -from cutlass_cppgen.utils.check import valid_stage_count -from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op - - -_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100] - - -class KernelsForDataType: - """ - Container class for keeping track of kernels that correspond to a particular combination - of data types for operands A, B, and accumulator - """ - - def __init__(self, datatype_comb: tuple, layout_comb: tuple): - self.datatype_comb = datatype_comb - self.layout_comb = layout_comb - self.math_operations = set() - - # Dictionary mapping from alignment (int) to a list of kernels that fit the alignment - # constraint for the data type combination - self.kernels_by_alignment = {} - - def add(self, operation): - """ - Add an operation to the list of supported kernels - """ - alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}" - if alignment_key not in self.kernels_by_alignment: - self.kernels_by_alignment[alignment_key] = [] - self.kernels_by_alignment[alignment_key].append(operation) - self.math_operations.add(operation.tile_description.math_instruction.math_operation) - - def alignments(self, operand: str): - """ - Returns an unsorted list of alignments supported by this data type combination - - :param operand: identifier of operand in question (e.g., A, B, C) - :type operand: str - - :return: unsorted list of alignments supported by this data type combination - :rtype: list - """ - operand_idx = self._operand_idx(operand) - return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()] - - @property - def all_operations(self): - """ - Returns a list of all operations supported by this data type combination - - :return: list of all operations supported by this data type combination - :rtype: list - """ - ops = [] - for _, alignment_ops in self.kernels_by_alignment.items(): - ops.extend(alignment_ops) - return ops - - def default_operation(self, math_operation: cutlass_cppgen.MathOperation): - key = sorted(list(self.kernels_by_alignment.keys()))[0] - kernels = self.kernels_by_alignment[key] - if math_operation is not None: - kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation] - return kernels[0] - - def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation): - """ - Returns operations satisfying the alignment constraints - - :param alignment_A: alignment constraint of operations to return - :type alignment_A: int - :param alignment_B: alignment constraint of operations to return - :type alignment_B: int - :param alignment_C: alignment constraint of operations to return - :type alignment_C: int - :param math_operation: math operation to consider - :type math_operation: cutlass_cppgen.MathOperation - - :return: list of operations - :rtype: list - """ - key = f"{alignment_A} {alignment_B} {alignment_C}" - - if key not in self.kernels_by_alignment: - og_key = key - # Reconcile A, B, and C alignments by trying to align to the minimum - min_alignment = min(alignment_A, alignment_B, alignment_C) - key = f"{min_alignment} {min_alignment} {min_alignment}" - if key not in self.kernels_by_alignment: - # Finally, go through all available alignment combinations and find - # one for which all values are less than those passed in. - key = None - alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) - for align_A, align_B, align_C in alignments: - if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0: - key = f"{align_A} {align_B} {align_C}" - break - - if key is None: - raise Exception( - f"No operations of alignment {og_key} found for data type and layout " - f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments " - f"are {self.kernels_by_alignment.keys()}" - ) - - ops = self.kernels_by_alignment[key] - if math_operation is not None: - ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation] - return ops - - def _operand_idx(self, key: str) -> int: - operand_list = ["A", "B", "C"] - if key not in operand_list: - raise Exception(f"Unexpected operand {operand}") - - return operand_list.index(key) - - def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int: - """ - Returns the most preferable alignment for a given shape and layout - - :param shape: extent of each dimension of the tensor - :type shape: tuple - :param layout: layout of the tensor - :type layout: cutlass_cppgen.LayoutType - :param operand: descriptor of the operand in question - :type operand: str - - :return: maximum alignment supported by the data type combination and tensor size - :rtype: int - """ - operand_idx = self._operand_idx(operand) - - # Determine the leading dimension of the shape - if layout == cutlass_cppgen.LayoutType.ColumnMajor: - ld = shape[-2] - elif layout == cutlass_cppgen.LayoutType.RowMajor: - ld = shape[-1] - elif layout == cutlass_cppgen.LayoutType.TensorNHWC: - ld = shape[-1] - else: - raise Exception(f"Unexpected or unsupported layout {layout}") - - for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True): - alignment = int(alignments.split(" ")[operand_idx]) - if ld % alignment == 0: - return alignment - - # Default to alignment of 1 if no others match - return 1 - - def sort(self): - """ - Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape - """ - key = lambda op: ( - op.tile_description.threadblock_shape[0] - * op.tile_description.threadblock_shape[1] - * op.tile_description.threadblock_shape[2] - ) - for alignment in self.kernels_by_alignment.keys(): - self.kernels_by_alignment[alignment].sort(key=key, reverse=True) - - def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool: - """ - Returns whether `math_operation` is supported by at least one operation. - - :param math_operation: math operation to consider - :type math_operation: cutlass_cppgen.MathOperation - - :return: whether math_operation is supported by at least one operation - :rtype: bool - """ - return math_operation is None or math_operation in self.math_operations - - -class ArchOptions: - """ - Structure for keeping track of kernels available on a given compute capability - - :param target_cc: compute capability of the device on which kernels will be run - :type target_cc: int - :param kernel_cc: compute capability of the kernels to generate - :type kernel_cc: int - :param operation_kind: type of operation to register - :type operation_kind: cutlass_library.OperationKind - :param gemm_kinds: types of GEMM operations that can be included - :type gemm_kinds: list - :param allowed_math_operations: types of primitive math operations allowed - :type allowed_math_operations: list - """ - - def __init__( - self, - target_cc: int, - kernel_cc: int, - operation_kind: cutlass_library.OperationKind, - gemm_kinds: list, - allowed_math_operations: list = [ - cutlass_library.MathOperation.multiply_add, - cutlass_library.MathOperation.multiply_add_saturate, - cutlass_library.MathOperation.multiply_add_mixed_input_upcast, - cutlass_library.MathOperation.multiply_add_fast_f32 - ] - ): - self.cc = kernel_cc - - # Dictionary with following structure: - # Key: OpcodeClass - # Value: Dictionary with the following structure: - # Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType), - # representing ((element_a, element_b, element_accumulator), (layout_a, layout_b)) - # Value: KernelsForDataType - self.operations_by_opclass = {} - self.op_class = None - self.allowed_math_operations = allowed_math_operations - - if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100: - return - - # Identify the method within CUTLASS generator script that generates kernel - # descriptions for the target CC - generate_function_name = "GenerateSM" + str(kernel_cc) - if not hasattr(cutlass_library.generator, generate_function_name): - cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}") - return - generate_function = getattr(cutlass_library.generator, generate_function_name) - - # Initialize a default manifest and populate it with valid kernel descriptions - # for the target CC - args = [ - "--kernels=all", - f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}" - ] - manifest_args = cutlass_library.generator.define_parser().parse_args(args) - manifest = cutlass_library.manifest.Manifest(manifest_args) - generate_function(manifest, cutlass_cppgen._nvcc_version) - - if operation_kind not in manifest.operations: - # No kernels generated for this architecture, this could be because the CUDA - # toolkit is insufficient to support operations in this CC - cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}") - return - - # Only one CC should be returned, given the setup above of calling only the generation scripts - # for a given CC - if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]: - raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version " - "is sufficient for the architecture in question.") - - # Iterate through the available operations for this operation kind and - # find available opclasses and data types - for name, op_list in manifest.operations[operation_kind][kernel_cc].items(): - for op in op_list: - - if operation_kind == cutlass_library.OperationKind.Gemm: - if op.gemm_kind not in gemm_kinds: - continue - - mi = op.tile_description.math_instruction - if mi.math_operation not in self.allowed_math_operations: - continue - - # Prune operations that don't fit in shared memory - td = td_from_profiler_op(op) - if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]: - continue - - if mi.opcode_class not in self.operations_by_opclass: - self.operations_by_opclass[mi.opcode_class] = {} - - datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator) - layout_comb = (op.A.layout, op.B.layout) - - # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations - if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32): - # TF32 kernels only supported on SM80 and beyond - if self.cc < 80: - continue - elif self.cc == 90 or self.cc == 100: - if (op.A.element != cutlass_library.DataType.f32 - or op.B.element != cutlass_library.DataType.f32 - or op.C.element != cutlass_library.DataType.f32): - continue - - datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32) - - opclass_dict = self.operations_by_opclass[mi.opcode_class] - key = (datatype_comb, layout_comb) - if key not in opclass_dict: - opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb) - opclass_dict[key].add(op) - - # Set the default opclass to TensorOp, if available. Otherwise default to SIMT - if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass: - self.op_class = cutlass_library.OpcodeClass.TensorOp - else: - self.op_class = cutlass_library.OpcodeClass.Simt - - # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels. - # Here, we generate additional versions via a generic TileDescription. - if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass: - self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {} - - if operation_kind == cutlass_library.OperationKind.Gemm: - types = [ - (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8), - (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32), - (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16), - (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32), - (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32), - (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), - ] - - # Add FP8 A/B/C - fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2] - for type_comb in combinations_with_replacement(fp8_types, 3): - types.append(type_comb) - - # Add FP8 A/B with FP32 C - for type_comb in combinations_with_replacement(fp8_types, 2): - types.append(type_comb + (cutlass_cppgen.DataType.f32,)) - - layouts = [ - (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor), - (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor), - (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor), - (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor), - ] - elif operation_kind == cutlass_library.OperationKind.Conv2d: - types = [ - (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16), - (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32), - (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32), - (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), - ] - - layouts = [ - (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC), - ] - else: - raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.") - - alignment = 1 - epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination - swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8 - for type_comb in types: - for layout_comb in layouts: - comb = (type_comb, layout_comb) - if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]: - continue - - A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment) - B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment) - C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment) - math_inst = cutlass_library.MathInstruction( - [1, 1, 1], - type_comb[0], - type_comb[1], - type_comb[2], - cutlass_library.OpcodeClass.Simt, - cutlass_library.MathOperation.multiply_add - ) - - td = cutlass_library.TileDescription( - [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024) - - # Prune operations that don't fit in shared memory - if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]: - continue - - new_kernels = KernelsForDataType(type_comb, layout_comb) - - if operation_kind == cutlass_library.OperationKind.Gemm: - new_operation = cutlass_library.manifest.GemmOperation( - cutlass_library.GemmKind.Universal, td.minimum_compute_capability, - td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor) - new_kernels.add(new_operation) - elif operation_kind == cutlass_library.OperationKind.Conv2d: - for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: - new_operation = cutlass_library.manifest.Conv2dOperation( - conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td, - A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor, - group_mode=GroupMode.SingleGroup - ) - new_kernels.add(new_operation) - - self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels - - # Sort all operations - for oc in self.operations_by_opclass.keys(): - for comb in self.operations_by_opclass[oc].keys(): - self.operations_by_opclass[oc][comb].sort() - - def opclass_supports_combination( - self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation - ) -> bool: - """ - Returns whether the provided operation class supports the provided data type and layout combination - - :param op_class: operation class to consider - :type op_class: cutlass_library.OpcodeClass - :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator) - :type datatype_comb: tuple[cutlass_library.DataType] - :param layout_comb: tuple of data types for (layout_A, layout_B) - :type layout_comb: tuple[cutlass_library.LayoutType] - :param math_operation: math operation to consider or None if any can be considered - :type math_operation: cutlass_cppgen.MathOperation - - :return: set of operation classes that support the provided data type and layout combination - :rtype: set - """ - if op_class not in self.operations_by_opclass: - raise Exception(f"Unexpected or unsupported operation class {op_class}") - - if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)): - if math_operation is not None: - return operations.supports_math_operation(math_operation) - else: - return True - - return False - - - def supporting_opclasses( - self, - element_a: cutlass_library.DataType, - element_b: cutlass_library.DataType, - element_accumulator: cutlass_library.DataType, - layout_a: cutlass_library.LayoutType, - layout_b: cutlass_library.LayoutType, - math_operation: cutlass_library.MathOperation, - ) -> set: - """ - Returns a set of operation classes that support the provided data type combination - - :param element_a: data type of operand A - :type element_a: cutlass_library.DataType - :param element_b: data type of operand B - :type element_b: cutlass_library.DataType - :param element_accumulator: data type of accumulator - :type element_accumulator: cutlass_library.DataType - :param layout_a: layout of operand A - :type layout_a: cutlass_library.LayoutType - :param layout_b: layout of operand B - :type layout_b: cutlass_library.LayoutType - :param math_operation: math operation to consider - :type math_operation: cutlass_cppgen.MathOperation - - :return: set of operation classes that support the provided data type combination - :rtype: set - """ - supporting_op_classes = set() - datatype_comb = (element_a, element_b, element_accumulator) - layout_comb = (layout_a, layout_b) - - for op_class in self.operations_by_opclass.keys(): - if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation): - supporting_op_classes.add(op_class) - return supporting_op_classes - - def operations( - self, - op_class: cutlass_library.OpcodeClass, - element_a: cutlass_library.DataType, - element_b: cutlass_library.DataType, - element_accumulator: cutlass_library.DataType, - layout_a: cutlass_library.LayoutType, - layout_b: cutlass_library.LayoutType, - math_operation: cutlass_library.MathOperation, - ) -> KernelsForDataType: - """ - Returns whether the provided operation class supports the provided data type combination - - :param op_class: operation class to consider - :type op_class: cutlass_library.OpcodeClass - :param element_a: data type of operand A - :type element_a: cutlass_library.DataType - :param element_b: data type of operand B - :type element_b: cutlass_library.DataType - :param element_accumulator: data type of accumulator - :type element_accumulator: cutlass_library.DataType - :param layout_a: layout of operand A - :type layout_a: cutlass_library.LayoutType - :param layout_b: layout of operand B - :type layout_b: cutlass_library.LayoutType - :param math_operation: math operation to consider - :type math_operation: cutlass_cppgen.MathOperation - - :return: container of kernels by alignment supported by the provided combination of parameters - :rtype: KernelsForDataType - """ - datatype_comb = (element_a, element_b, element_accumulator) - layout_comb = (layout_a, layout_b) - if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation): - raise Exception( - f"Data type layout combination {datatype_comb}, {layout_comb} " - f"is not supported by opcode class {op_class} on CC {self.cc}." - ) - return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)] - - -class OptionRegistry: - """ - Container of all architecture-specific options - - :param target_cc: compute capability of the device on which operations will be run - :type target_cc: int - """ - - def __init__(self, target_cc: int): - self.registry = {} - - if target_cc > 100 and (target_cc not in [101, 103, 120, 121]): - raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.") - - gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x] - operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d] - # Construct options for each CC - for kernel_cc in _generator_ccs: - self.registry[kernel_cc] = {} - for opkind in operation_kinds: - self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds) - - def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions: - return self.registry.get(cc, None)[op_kind] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py deleted file mode 100644 index 0286907040fb3ded84f989bfc9d14e740307f6a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad -from cutlass_cppgen.op.gemm import Gemm -from cutlass_cppgen.op.gemm_grouped import GroupedGemm -from cutlass_cppgen.op.op import OperationBase diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py deleted file mode 100644 index 711b27da13b54e30f8b25e839ffc4f51ed80dc5c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py +++ /dev/null @@ -1,997 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" - Ease-of-use interface for constructing, compiling, and running CONVs - - The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run - CONV2D operations in CUTLASS via Python, without specifying many configuration parameters. - Under the hood, the interface will select sensible default parameters for the many template - parameters for CUTLASS CONVs. - - Note: optimal performance is not to be expected from this interface. To achieve optimal - performance, one should specify and tune each configuration parameter. - - The simplest example of using this interface is the following: - - .. highlight:: python - .. code-block:: python - - # A, B, C, and D are torch/numpy/cupy tensor objects - plan = cutlass_cppgen.op.Conv(A, B, C, D) - plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1)) - - One can also use the interface by specifying data types of operands at construction - and using different tensor objects with these data types at runtime: - - .. highlight:: python - .. code-block:: python - - # The following is shorthand for: - # cutlass_cppgen.op.Conv2d(kind="fprop", - # element_A=torch.float32, element_B=torch.float32, - # element_C=torch.float32, element_D=torch.float32, - # element_accumulator=torch.float32) - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32) - - A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda') - B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda') - C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda') - D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda') - plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) - - A = torch.rand((32, 128), dtype=torch.float32, device='cuda') - B = torch.rand((128, 256), dtype=torch.float32, device='cuda') - C = torch.zeros((32, 256), dtype=torch.float32, device='cuda') - D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda') - plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) - - The interface additionally enables one to decouple the compilation of the underlying CUTLASS - kernel from its execution: - - .. highlight:: python - .. code-block:: python - - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) - - # Do other work... - - plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) - - # Do other work... - - plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) - - Elementwise activation functions are easily fused to the GEMM via the interface: - - .. highlight:: python - .. code-block:: python - - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) - plan.activation = cutlass_cppgen.epilogue.relu - - Operations can also be run asynchronously: - - .. highlight:: python - .. code-block:: python - - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) - args = plan.run() - - # Do other work... - - args.sync() -""" - -from __future__ import annotations -from typing import Optional -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") -from cutlass_library import ( - ConvKind, - ConvMode, - DataTypeSize, - IteratorAlgorithm, - OperationKind, - SplitKMode, - StrideSupport, -) - -import cutlass_cppgen -from cutlass_cppgen import epilogue -from cutlass_cppgen.backend import compiler -from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation -from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments -from cutlass_cppgen.backend.library import TensorDescription, TileDescription -from cutlass_cppgen.op.op import OperationBase -from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord -from cutlass_cppgen.utils import check, datatypes - - -class Conv2d(OperationBase): - """ - Constructs a ``Conv2d`` object. - - The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C, - along with the data type of output D and that used for accumulation, are bound to the ``Conv`` - object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed. - - The constructor has optional parameters for flexibly setting these parameters. The following - constructors are equivalent: - - .. highlight:: python - .. code-block:: python - - # Use F32 for A, B, C, D, and accumulation in fprop - - # Use the generic ``element`` parameter to concisely set all data types for operands to the same values. - Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32) - - # Explicitly specify the data types to use for A, B, C, and D. - Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, - element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32) - - # Set the data types and elements from existing tensors. Note that one can use different tensors when - # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must - # have the same data type as those passed in here). - # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout - Conv2d(kind="fprop", A=A, B=B, C=C, D=D) - - # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit - # those passed in via the generic ``element`` - Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, - element=cutlass_cppgen.DataType.f32) - - The order of precedence for the setting of the data type for a given operand/output is as follows: - 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor - 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those - 3) Otherwise, use the generic values (e.g., ``element``) - - :param kind: the convolution kind (i.e. fprop, wgrad, and dgrad) - :type kind: str - :param A: tensor representing data type of operand A - :param B: tensor representing data type of operand B - :param C: tensor representing data type of operand C - :param D: tensor representing data type of operand D - :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B - :param beta: scalar parameter beta from GEMM operation that scales operand C - :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type - :type element: cutlass_cppgen.DataType - :param element_A: data type to be used for operand A - :type element_A: cutlass_cppgen.DataType - :param element_B: data type to be used for operand B - :type element_B: cutlass_cppgen.DataType - :param element_C: data type to be used for operand C - :type element_C: cutlass_cppgen.DataType - :param element_D: data type to be used for operand D - :type element_D: cutlass_cppgen.DataType - :param element_accumulator: data type to be used in accumulation of the product of operands A and B - :type element_accumulator: cutlass_cppgen.DataType - :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 - :type cc: int - :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 - :type kernel_cc: int - """ - def __init__( - self, kind="fprop", - A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, - element=None, - element_A=None, element_B=None, element_C=None, element_D=None, - element_accumulator=None, - cc: int = None, kernel_cc: int = None - ): - super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) - # Verify the kernel cc - if self.current_cc in [90, 100, 101, 103]: - # The Conv2d kernel on Hopper (SM90) is currently unsupported - # Revert to use SM80-tagged kernels - cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") - self.specified_kernel_cc = 80 - self._reset_options(80) - - # The arch is used in testing - self.arch = self.current_cc - self.name = "conv2d" + kind - - # The convolution kind. (concept: cutlass_library.library.ConvKind) - self.conv_kind = datatypes.getattr_enum(ConvKind, kind) - - # The element types (concept: cutlass library types) of A, B, C, and D - elements = [] - layouts = [] - - # Complete the data types based on user-provided arguments - for elt, tens, name in zip([element_A, element_B, element_C, element_D], - [A, B, C, D], - ["A", "B", "C", "D"]): - if elt is not None and tens is not None: - raise Exception(f'Must not specify both element_{name} and tensor {name}') - if elt is None and tens is None and element is None: - raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') - - elt_to_set = None - lay_to_set = None - - if tens is not None: - elt_to_set, _ = datatypes.get_datatype_and_layout(tens) - else: - elt_to_set = elt if elt is not None else element - - assert elt_to_set is not None - - # Currently we only support layout TensorNHWC - lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC - elements.append(datatypes.library_type(elt_to_set)) - layouts.append(lay_to_set) - - self._element_a, self._element_b, self._element_c, self._element_d = elements - self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts - - self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta - - if element_accumulator is None: - self._element_accumulator = self._element_c - else: - self._element_accumulator = datatypes.library_type(element_accumulator) - - # Default inputs if none is supplied in run() - self.A = A - self.B = B - self.C = C - self.D = D - - self.alpha = alpha - self.beta = beta - - # We only specify the stride of the swizzling functor here - # The actual swizzling functor is determined in run based on conv_kind and stride - self._swizzling_stride = 1 - - # Arguments that will be set to default value in _reset_operations - # The default tile_description and op_class are fetched from manifest of cutlass library - self._tile_description = None - self.op_class = None - # The default identity epilogue will be created - self.epilogue_functor = None - - self._reset_operations() - - # Arguments that will be determined online based on arguments of "run" - # based on stride, input/output channels, alignment, and conv_kind - self._iterator_algorithm = None - self._stride_support = None - - def _reset_operations(self, reset_epilogue: bool = True): - # Set the default op class - datatype_comb = (self._element_a, self._element_b, self._element_accumulator) - layout_comb = (self._layout_a, self._layout_b) - - self.possible_op_classes = self.options.supporting_opclasses( - self._element_a, self._element_b, self._element_accumulator, - self._layout_a, self._layout_b, self._math_operation - ) - - if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: - self.opclass = cutlass_cppgen.OpcodeClass.TensorOp - elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: - self.opclass = cutlass_cppgen.OpcodeClass.Simt - else: - if self._math_operation is not None: - math_op_str = f' and math operation {self._math_operation}' - else: - math_op_str = '' - - raise Exception(f'No kernel configuration found for supported data type and layout ' - f'combination {datatype_comb}x{layout_comb}{math_op_str}') - - if reset_epilogue: - self._reset_epilogue_functor_activation(epilogue.identity) - - self.alignment_pref_A = min( - 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) - self.alignment_pref_B = min( - 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) - self.alignment_pref_C = min( - 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C"))) - - # - # Tile description Related - # - - @property - def tile_description(self) -> TileDescription: - """ - Returns the tile description - """ - return self._tile_description - - @tile_description.setter - def tile_description( - self, td=None): - """ - Set the tile description - - :param td: tile description - :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys - { - "threadblock_shape": [int, int, int], - "warp_count": [int, int, int], - "stages": int, - "instruction_shape": [int, int, int] (optional), - "cluster_shape": [int, int, int] (optional) - } - """ - if td is None: - return - if isinstance(td, dict): - if self._tile_description is None: - op = self.possible_operations.default_operation(self._math_operation) - self._tile_description = datatypes.td_from_profiler_op(op) - if "cluster_shape" in td.keys(): - if td["cluster_shape"] != [1, 1, 1]: - cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") - td["cluster_shape"] = [1, 1, 1] - td = self._tile_description.clone_and_update(td) - - valid, msg = self._valid_tile_description(td) - if valid: - self._tile_description = td - else: - raise Exception(msg) - - def _valid_tile_description(self, td: TileDescription) -> tuple: - """ - Checks whether the provided tile description is valid for the given compute capability. At present, - this checks the following: - - - Does the tile description use a number of stages supported by the compute capability in question? - - Does the tile size requested fit within shared memory? - - Are cluster dimensions outside the valid range requested for a given architecture (e.g., - more non-unit cluster dimensions for pre-SM90 architectures)? - - Is the kernel schedule being used supported on the architecture in question? - - :param td: tile description to validate - :type td: cutlass_cppgen.backend.TileDescription - :return: tuple in which the first element is a bool indicating that the tile description is valid - and the second element is a string providing an optional error message. - :rtype: tuple - """ - valid, msg = check.valid_stage_count(self.cc, self.current_cc, td) - if not valid: - return (valid, msg) - - valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) - if not valid: - return (valid, msg) - - return valid, msg - - def tile_descriptions(self) -> list: - """ - Returns a list of valid tile descriptions for the operations - - :returns: list of valid tile descriptions for the operations - :rtype: list - """ - descriptions = [] - description_str = [] - for op in self.possible_operations.all_operations: - td = datatypes.td_from_profiler_op(op) - - if self._math_operation is not None: - if td.math_instruction.math_operation != self._math_operation: - continue - - if str(td) not in description_str: - description_str.append(str(td)) - descriptions.append(td) - return descriptions - - # - # Swizzling functor Related - # - - @property - def swizzling_stride(self): - """ - Returns the stride of swizzling currently being used by the Conv2d - - :return: swizzing stride - """ - return self._swizzling_stride - - @swizzling_stride.setter - def swizzling_stride(self, stride: int): - """ - Sets the swizzling functor to the type specified by `swizzling_functor` - """ - if not isinstance(stride, int): - raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}") - self._swizzling_stride = stride - - def _propose_swizzling_functor(self, stride): - """ - Automatically propose the swizzling functor based on the stride - """ - if self.conv_kind == ConvKind.Dgrad: - if stride[0] != 1 or stride[1] != 1: - return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") - - return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}") - - # - # Iterator Algorithm Related - # - - @property - def iterator_algorithm(self) -> IteratorAlgorithm: - """ - Returns the iterator algorithm - """ - return self._iterator_algorithm - - @iterator_algorithm.setter - def iterator_algorithm(self, alg: str): - """ - Sets the iterator algorithm - - :param alg: The iterator algorithm - :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels" - """ - iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg) - - # Check if the iterator algorithm is valid - if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop: - raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.") - - self._iterator_algorithm = iterator_alg - - def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm: - """ - Propose a valid iterator algorithm based on problem size and alignment - """ - if self.conv_kind == ConvKind.Fprop: - # Check whether the fixed channel is applicable - if problem_size.C == alignment_a: - return IteratorAlgorithm.FixedChannels - elif (problem_size.C % alignment_a == 0 and - problem_size.R <= 32 and problem_size.S <= 32): - return IteratorAlgorithm.Optimized - else: - return IteratorAlgorithm.Analytic - elif self.conv_kind == ConvKind.Dgrad: - if (problem_size.K % alignment_a == 0 and - problem_size.R <= 32 and problem_size.S <= 32 and - problem_size.C % alignment_b == 0): - return IteratorAlgorithm.Optimized - else: - return IteratorAlgorithm.Analytic - elif self.conv_kind == ConvKind.Wgrad: - if (problem_size.K % alignment_a == 0 and - problem_size.C % alignment_b == 0): - return IteratorAlgorithm.Optimized - else: - return IteratorAlgorithm.Analytic - - def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool: - """ - Validate whether the user provide iterator algorithm works for the given problem size - """ - if self.conv_kind == ConvKind.Fprop: - if iterator_algorithm == IteratorAlgorithm.FixedChannels: - return problem_size.C == alignment_a - elif iterator_algorithm == IteratorAlgorithm.Optimized: - return (problem_size.C % alignment_a == 0 and - problem_size.R <= 32 and problem_size.S <= 32) - elif iterator_algorithm == IteratorAlgorithm.FewChannels: - return problem_size.C % alignment_a == 0 - elif self.conv_kind == ConvKind.Dgrad: - if iterator_algorithm == IteratorAlgorithm.Optimized: - return (problem_size.K % alignment_a == 0 and - problem_size.R <= 32 and problem_size.S <= 32 and - problem_size.C % alignment_b == 0) - elif self.conv_kind == ConvKind.Wgrad: - if iterator_algorithm == IteratorAlgorithm.Optimized: - return (problem_size.K % alignment_a == 0 and - problem_size.C % alignment_b == 0) - - return True - - # - # Stride Support Related - # - - def _propose_stride_support(self, stride): - if self.conv_kind == ConvKind.Dgrad: - if stride[0] == 1 and stride[1] == 1: - return StrideSupport.Unity - - return StrideSupport.Strided - - # - # Construct and Compilation - # - - def construct( - self, tile_description: TileDescription = None, - alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, - iterator_algorithm: IteratorAlgorithm = None, - stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, - epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation: - """ - Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current - kernel specification of the ``Conv2d`` object. - - :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass_cppgen.backend.TileDescription - :param alignment_A: alignment of operand A - :type alignment_A: int - :param alignment_B: alignment of operand B - :type alignment_B: int - :param alignment_C: alignment of operand C - :type alignment_C: int - :param iterator_algorithm: the iterator algorithm used - :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm - :param stride_support: the stride support of dgrad - :type stride_support: cutlass_library.library.StrideSupport - :param swizzling_functor: the swizzling functor - :type swizzling_functor: cutlass_cppgen.swizzle - :param epilogue_functor: the epilogue functor - - :return: operation that was constructed - :rtype: cutlass_cppgen.backend.Conv2dOperation - """ - # Get alignment - alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A) - alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B) - alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C) - - tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) - tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) - tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) - - if tile_description is None: - if self.tile_description is not None: - tile_description = self.tile_description - else: - op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] - tile_description = datatypes.td_from_profiler_op(op) - else: - valid, err_str = self._valid_tile_description(tile_description) - if not valid: - raise Exception(f"Invalid tile description. {err_str}") - self.tile_description = tile_description - - if iterator_algorithm is None: - # If the iterator algorithm is already set - if self.iterator_algorithm is not None: - iterator_algorithm = self.iterator_algorithm - else: - # Otherwise, we conservatively use the analytic iterator for correctness - iterator_algorithm = IteratorAlgorithm.Analytic - - if stride_support is None: - # If the stride support is already set - if self._stride_support is not None: - stride_support = self._stride_support - else: - # Otherwise, we assume strided - stride_support = StrideSupport.Strided - - if swizzling_functor is None: - # If the swizzling functor is already set - swizzling_functor = self._propose_swizzling_functor(stride=(2, 2)) - - if epilogue_functor is None: - if self.epilogue_functor is not None: - epilogue_functor = self.epilogue_functor - else: - epilogue_functor = self._create_epilogue_functor_activation(self._activation) - - # Reset the alignment of the epilogue functor - epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor) - - operation = Conv2dOperation( - conv_kind=self.conv_kind, - iterator_algorithm=iterator_algorithm, - arch=self.current_cc, - tile_description=tile_description, - A=tensor_A, B=tensor_B, C=tensor_C, - stride_support=stride_support, - epilogue_functor=epilogue_functor, - swizzling_functor=swizzling_functor, - ) - - return operation - - def compile(self, tile_description: TileDescription = None, - alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, - iterator_algorithm: IteratorAlgorithm = None, - stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, - epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation: - """ - Emits and compiles the kernel currently specified. If ``tile_description`` and any - of the ``alignment`` parameters are set, the kernel will be chosen using this - tile description and alignments. Otherwise, a default tile description and alignment - will be used. - - ::param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass_cppgen.backend.TileDescription - :param alignment_A: alignment of operand A - :type alignment_A: int - :param alignment_B: alignment of operand B - :type alignment_B: int - :param alignment_C: alignment of operand C - :type alignment_C: int - :param iterator_algorithm: the iterator algorithm used - :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm - :param stride_support: the stride support of dgrad - :type stride_support: cutlass_library.library.StrideSupport - :param swizzling_functor: the swizzling functor - :type swizzling_functor: cutlass_cppgen.swizzle - :param epilogue_functor: the epilogue functor - - :return: operation that was compiled - :rtype: cutlass_cppgen.backend.Conv2dOperation - """ - - self.operation = self.construct( - tile_description, alignment_A, alignment_B, alignment_C, - iterator_algorithm, stride_support, swizzling_functor, epilogue_functor) - - if print_module: - print(self.operation.rt_module.emit()) - - compiler.add_module([self.operation,]) - return self.operation - - # - # Run Related - # - - def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): - """ - Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception - is raised if it does not. - - :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type tensor: numpy/cupy/torch array/tensor object - :param ref_dtype: data type for the tensor that this object was initialized to - :param name: identifier of the tensor to verify. Used in raising exceptions - :type name: str - """ - dtype, _ = datatypes.get_datatype_and_layout(tensor) - if dtype != ref_type: - raise Exception(f'Tensor {name} with type and layout {dtype} ' - f'does not match the expected type of {ref_type}.') - - def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation): - if self.conv_kind == ConvKind.Fprop: - input = A - weight = B - output = C - output_tensor = "C" - elif self.conv_kind == ConvKind.Dgrad: - output = A - weight = B - input = C - output_tensor = "A" - elif self.conv_kind == ConvKind.Wgrad: - output = A - input = B - weight = C - output_tensor = "A" - else: - raise Exception(f"Convolution kind {self.conv_kind} is not supported") - - N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV") - K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV") - _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV") - - problem_size = Conv2DProblemSize( - N_, H_, W_, C_, - K_, R_, S_, C_, - padding[0], padding[1], - stride[0], stride[1], - dilation[0], dilation[1], - ConvMode.CrossCorrelation, - 1, 1 - ) - - if P_ != problem_size.P or Q_ != problem_size.Q: - raise Exception( - f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})") - - return problem_size - - def run(self, A=None, B=None, C=None, D=None, - stride=(1, 1), padding=(0, 0), dilation=(1, 1), - alpha=None, beta=None, - split_k=("serial", 1), sync: bool = True, - print_module: bool = False, - stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: - """ - Runs the kernel currently specified. If it has not already been, the kernel is emitted and - compiled. Tensors holding operands and outputs of the kernel are sourced either from the - ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` - parameters provided in the call, or from those - passed in on the construction of this object -- one of the two must be specified. - - By default, this call returns only once the kernel has completed. To launch the kernel - and immediately return, set ``sync=False``. In this case, it is the responsibility of the - caller to syncrhonize the results of the kernel before attempting to access outputs - by calling ``sync()`` on the arguments returned from this call. - - :param A: tensor representing data type and layout of operand A - :param B: tensor representing data type and layout of operand B - :param C: tensor representing data type and layout of operand C - :param D: tensor representing data type and layout of operand D - :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1) - :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0) - :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1) - :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B - :param beta: scalar parameter beta from GEMM operation that scales operand C - :param split_k: a tuple (split_k_mode, split_k_slices) - :param sync: whether the call should wait for the kernel to complete before returning - :type sync: bool - :param print_module: whether to print the emitted C++ code - :type print_module: bool - :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) - :type stream: :class:`cuda.cuda.CUstream` - - :return: arguments passed in to the kernel - :rtype: cutlass_cppgen.backend.Conv2dArguments - """ - if not stream: - stream = cuda.CUstream(0) - super().run_setup() - - A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") - B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") - C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") - D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") - alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") - beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") - - # handle the case when there is no C - if C is None: - if beta != 0: - raise Exception(f"With beta {beta} != 0, C has to be provided.") - else: - C = D - - # Construct problem size based on input - # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching - problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation) - - # Propose stride support based on input - stride_support = self._propose_stride_support(stride) - - # Propose swizzling functor - swizzling_functor = self._propose_swizzling_functor(stride) - - shape_a = datatypes.get_tensor_shape(A, op="CONV") - shape_b = datatypes.get_tensor_shape(B, op="CONV") - shape_c = datatypes.get_tensor_shape(C, op="CONV") - - # Get the alignment - alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A") - alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B") - alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C") - - alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A) - alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B) - alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C) - - # Propose iterator algorithm based on input - if self._iterator_algorithm is None: - # Propose a default iterator algorithm based on the problem size - iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b) - else: - if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)): - iterator_algorithm = self._iterator_algorithm - else: - raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.") - - epilogue_args = [alpha, beta] - - if hasattr(self, "_activation_args"): - if isinstance(self._activation_args, list): - epilogue_args += self._activation_args - else: - epilogue_args.append(self._activation_args) - - if split_k[0] == "parallel" and split_k[1] > 1: - epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity) - else: - epilogue_functor = self.epilogue_functor - - # The alignment is determined by the iterator function (I believe) - self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, - alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support, - swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module) - - # Create reduction operation for parallel split-k - if split_k[0] == "parallel" and split_k[1] > 1: - epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor) - self.reduction_operation = ReductionOperation( - shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C, - element_accumulator=self._element_accumulator, - element_compute=self._element_accumulator, - epilogue_functor=epilogue_functor_reduction, - count=alignment_c - ) - if print_module: - print(self.reduction_operation.rt_module.emit()) - compiler.add_module([self.reduction_operation,]) - - arguments = Conv2dArguments( - operation=self.operation, problem_size=problem_size, - A=A, B=B, C=C, D=D, - output_op=self.operation.epilogue_type(*epilogue_args), - split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]), - split_k_slices=split_k[1], - stream=stream - ) - - self.operation.run(arguments) - - if split_k[0] == "parallel" and split_k[1] > 1: - implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind) - reduction_arguments = ReductionArguments( - self.reduction_operation, - problem_size=[implicit_gemm_size.m, implicit_gemm_size.n], - partitions=split_k[1], - workspace=arguments.ptr_D, - destination=D, - source=C, - output_op=self.reduction_operation.epilogue_type(*epilogue_args), - stream=stream - ) - self.reduction_operation.run(reduction_arguments) - - if sync: - if split_k[0] == "parallel" and split_k[1] > 1: - reduction_arguments.sync() - - # Free memory allocated by args because we are not - # calling `arguments.sync()` in this case (which will free memory) - arguments.free() - else: - arguments.sync() - - return arguments - - # - # Helper functions - # - @staticmethod - def output_size(input_size, weight_size, padding, stride, dilation): - problem_size = Conv2DProblemSize( - *input_size, - *weight_size, - padding[0], padding[1], - stride[0], stride[1], - dilation[0], dilation[1], - ConvMode.CrossCorrelation, - 1, 1 - ) - return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K) - - -# -# Easy to use interfaces for fprop, wgrad, and dgrad -# - -class Conv2dFprop(Conv2d): - def __init__( - self, - input=None, weight=None, C=None, output=None, alpha=1, beta=0, - element=None, - element_input=None, element_weight=None, element_C=None, element_output=None, - element_accumulator=None, - cc: int = None, kernel_cc: int = None): - A, B, D = input, weight, output - element_A, element_B, element_D = element_input, element_weight, element_output - super().__init__( - "fprop", A, B, C, D, alpha, beta, element, - element_A, element_B, element_C, element_D, - element_accumulator, cc, kernel_cc) - - def run( - self, input=None, weight=None, C=None, output=None, alpha=None, beta=None, - stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), - sync: bool = True, print_module: bool = False, - stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: - - if not stream: - stream = cuda.CUstream(0) - - A, B, D = input, weight, output - return super().run( - A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) - - -class Conv2dDgrad(Conv2d): - def __init__( - self, - grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0, - element=None, - element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None, - element_accumulator=None, - cc: int = None, kernel_cc: int = None): - A, B, D = grad_output, weight, grad_input - element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input - super().__init__( - "dgrad", A, B, C, D, alpha, beta, element, - element_A, element_B, element_C, element_D, - element_accumulator, cc, kernel_cc) - - def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None, - stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), - sync: bool = True, print_module: bool = False, - stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: - # - if not stream: - stream = cuda.CUstream(0) - - A, B, D = grad_output, weight, grad_input - return super().run( - A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) - - -class Conv2dWgrad(Conv2d): - def __init__( - self, - grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0, - element=None, - element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None, - element_accumulator=None, - cc: int = None, kernel_cc: int = None): - A, B, D = grad_output, input, grad_weight - element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight - super().__init__( - "wgrad", A, B, C, D, alpha, beta, element, - element_A, element_B, element_C, element_D, - element_accumulator, cc, kernel_cc) - - def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None, - stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), - sync: bool = True, print_module: bool = False, - stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: - if not stream: - stream = cuda.CUstream(0) - - A, B, D = grad_output, input, grad_weight - return super().run( - A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py deleted file mode 100644 index a6f9b1ab43a1c45d0024e99e50e45813ba18866e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py +++ /dev/null @@ -1,725 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" - Ease-of-use interface for constructing, compiling, and running GEMMs. - - The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run - GEMM operations in CUTLASS via Python, without specifying many configuration parameters. - Under the hood, the interface will select sensible default parameters for the many template - parameters for CUTLASS GEMMs. - - Note: optimal performance is not to be expected from this interface. To achieve optimal - performance, one should specify and tune each configuration parameter. - - The simplest example of using this interface is the following: - - .. highlight:: python - .. code-block:: python - - # A, B, C, and D are torch/numpy/cupy tensor objects - plan = cutlass_cppgen.op.Gemm(A, B, C, D) - plan.run() - - - One can also use the interface by specifying data types of operands at construction - and using different tensor objects with these data types at runtime: - - .. highlight:: python - .. code-block:: python - - # The following is shorthand for: - # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32, - # element_C=torch.float32, element_D=torch.float32, - # element_accumulator=torch.float32, - # layout=cutlass_cppgen.LayoutType.RowMajor) - plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) - - A0 = torch.rand((128, 256), device='cuda') - B0 = torch.rand((256, 64), device='cuda') - C0 = torch.zeros((128, 64), device='cuda') - D0 = torch.zeros((128, 64), device.'cuda') - plan.run(A0, B0, C0, D0) - - A = torch.rand((32, 128), device='cuda') - B = torch.rand((128, 256), device='cuda') - C = torch.zeros((32, 256), device='cuda') - D = torch.zeros((32, 256), device.'cuda') - plan.run(A1, B1, C1, D1) - - The interface additionally enables one to decouple the compilation of the underlying CUTLASS - kernel from its execution: - - .. highlight:: python - .. code-block:: python - - plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) - plan.compile() - - # Do other work... - - plan.run(A0, B0, C0, D0) - - # Do other work... - - plan.run(A1, B1, C1, D1) - - Elementwise activation functions are easily fused to the GEMM via the interface: - - .. highlight:: python - .. code-block:: python - - plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) - plan.activation = cutlass_cppgen.epilogue.relu - - Operations can also be run asynchronously: - - .. highlight:: python - .. code-block:: python - - plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) - args = plan.run() - - # Do other work... - - args.sync() -""" -from __future__ import annotations -from typing import Optional -from math import prod - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -from cutlass_library import ( - DataType, - DataTypeSize, - GemmUniversalMode, - KernelScheduleSuffixes, -) - -import cutlass_cppgen -from cutlass_cppgen import epilogue, swizzle -from cutlass_cppgen.backend import compiler -from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor -from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal -from cutlass_cppgen.backend.library import TensorDescription, TileDescription -from cutlass_cppgen.op.op import OperationBase -from cutlass_cppgen.shape import GemmCoord -from cutlass_cppgen.utils import check, datatypes - - -class Gemm(OperationBase): - """ - Constructs a ``Gemm`` object. - - The data types and layouts of operands A, B, and C, along with the data type of output D - and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime -- - these are not to be changed after a ``Gemm`` has been constructed. - - The constructor has optional parameters for flexibly setting these parameters. The following - constructors are equivalent: - - .. highlight:: python - .. code-block:: python - - # Use F32 for A, B, C, D, and accumulation. All operands are row major. - - # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts - # for operands to the same values. - Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) - - # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``. - Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32, - element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) - - # Set the data types and elements from existing tensors. Note that one can use different tensors when - # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must - # have the same data type and layout as those passed in here). - # A, B, C, and D are row-major torch.Tensor objects of type torch.float32 - Gemm(A=A, B=B, C=C, D=D) - - # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is - # the same as that for D, at present) - Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor, - layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor) - - # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types - # and layouts will inherit those passed in via the generic ``element`` and ``layout`` - Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor, - element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) - - The order of precedence for the setting of the data type and layout for a given operand/output is as follows: - 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor - 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those - 3) Otherwise, use the generic values (e.g., ``element``, ``layout``) - - :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 - :type cc: int - :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 - :type kernel_cc: int - :param A: tensor representing data type and layout of operand A - :param B: tensor representing data type and layout of operand B - :param C: tensor representing data type and layout of operand C - :param D: tensor representing data type and layout of operand D - :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B - :param beta: scalar parameter beta from GEMM operation that scales operand C - :param element_accumulator: data type to be used in accumulation of the product of operands A and B - :type element_accumulator: cutlass_cppgen.DataType - :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type - :type element: cutlass_cppgen.DataType - :param layout: generic layout type to be used for operands A, B, C, and D - :type layout: cutlass_cppgen.LayoutType - :param element_A: data type to be used for operand A - :type element_A: cutlass_cppgen.DataType - :param element_B: data type to be used for operand B - :type element_B: cutlass_cppgen.DataType - :param element_C: data type to be used for operand C - :type element_C: cutlass_cppgen.DataType - :param element_D: data type to be used for operand D - :type element_D: cutlass_cppgen.DataType - :param layout_A: layout of operand A - :type layout_A: cutlass_cppgen.LayoutType - :param layout_B: layout of operand B - :type layout_B: cutlass_cppgen.LayoutType - :param layout_C: layout of operand C - :type layout_C: cutlass_cppgen.LayoutType - :param layout_D: layout of operand D - :type layout_D: cutlass_cppgen.LayoutType - """ - - def __init__( - self, A=None, B=None, C=None, D=None, - alpha=1.0, beta=0.0, element_accumulator=None, - element=None, layout=None, - element_A=None, element_B=None, element_C=None, element_D=None, - layout_A=None, layout_B=None, layout_C=None, - cc: int = None, kernel_cc: int = None - ): - super().__init__(cc=cc, kernel_cc=kernel_cc) - self.name = "gemm" - self.compiled = False - - elements = [] - layouts = [] - - # Check that at least one of the following is set for each tensor (illustrated assuming tensor A): - # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout`` - for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D], - [layout_A, layout_B, layout_C, layout_C], - [A, B, C, D], - ["A", "B", "C", "D"]): - if elt is not None and tens is not None: - raise Exception(f'Must not specify both element_{name} and tensor {name}') - if lay is not None and tens is not None: - raise Exception(f'Must not specify both layout_{name} and tensor {name}') - if elt is None and tens is None and element is None: - raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') - if lay is None and tens is None and layout is None: - raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.') - - elt_to_set = None - lay_to_set = None - if tens is not None: - elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens) - else: - elt_to_set = elt if elt is not None else element - lay_to_set = lay if lay is not None else layout - - elements.append(datatypes.library_type(elt_to_set)) - layouts.append(lay_to_set) - - self._element_a, self._element_b, self._element_c, self._element_d = elements - self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts - - if element_accumulator is None: - self._element_accumulator = self._element_c - else: - self._element_accumulator = datatypes.library_type(element_accumulator) - - self.A = A - self.B = B - self.C = C - self.D = D - - self.alpha = alpha - self.beta = beta - - self.epilogue_functor = None - self.op_class = None - self._tile_description = None - - self._reset_operations() - - self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1 - - def _reset_operations(self, reset_epilogue: bool = True): - # Set the default op class - datatype_comb = (self._element_a, self._element_b, self._element_accumulator) - layout_comb = (self._layout_a, self._layout_b) - - self.possible_op_classes = self.options.supporting_opclasses( - self._element_a, self._element_b, self._element_accumulator, - self._layout_a, self._layout_b, self._math_operation) - - if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: - self.opclass = cutlass_cppgen.OpcodeClass.TensorOp - elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: - self.opclass = cutlass_cppgen.OpcodeClass.Simt - else: - if self._math_operation is not None: - math_op_str = f' and math operation {self._math_operation}' - else: - math_op_str = '' - - raise Exception(f'No kernel configuration found for supported data type and layout ' - f'combination {datatype_comb}x{layout_comb}{math_op_str}') - - if reset_epilogue: - self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity) - - @property - def swizzling_functor(self): - """ - Returns the type of the swizzling functor currently being used by the GEMM - - :return: swizzing functor type - """ - return self._swizzling_functor - - @swizzling_functor.setter - def swizzling_functor(self, swizzling_functor): - """ - Sets the swizzling functor to the type specified by `swizzling_functor` - """ - if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK: - if self.op_class == cutlass_cppgen.OpcodeClass.Simt: - raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') - - if self.current_cc in [90, 100, 101, 103]: - raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+') - self._swizzling_functor = swizzling_functor - - # - # Tile description Related - # - - @property - def tile_description(self) -> TileDescription: - """ - Returns the tile description - """ - return self._tile_description - - @tile_description.setter - def tile_description( - self, td=None): - """ - Set the tile description - - :param td: tile description - :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys - { - "threadblock_shape": [int, int, int], - "warp_count": [int, int, int], - "stages": int, - "instruction_shape": [int, int, int] (optional), - "cluster_shape": [int, int, int] (optional) - } - """ - if td is None: - return - if isinstance(td, dict): - if self._tile_description is None: - op = self.possible_operations.default_operation(self._math_operation) - self._tile_description = datatypes.td_from_profiler_op(op) - td = self._tile_description.clone_and_update(td) - - valid, msg = self._valid_tile_description(td) - if valid: - self._tile_description = td - else: - raise Exception(msg) - - def _valid_tile_description(self, td: TileDescription) -> tuple: - """ - Checks whether the provided tile description is valid for the given compute capability. At present, - this checks the following: - - - Does the tile description use a number of stages supported by the compute capability in question? - - Does the tile size requested fit within shared memory? - - Are cluster dimensions outside the valid range requested for a given architecture (e.g., - more non-unit cluster dimensions for pre-SM90 architectures)? - - Is the kernel schedule being used supported on the architecture in question? - - :param td: tile description to validate - :type td: cutlass_cppgen.backend.TileDescription - :return: tuple in which the first element is a bool indicating that the tile description is valid - and the second element is a string providing an optional error message. - :rtype: tuple - """ - valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d) - if not valid: - return (valid, msg) - - valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) - if not valid: - return (valid, msg) - - valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) - - if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0: - valid = False - msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103" - - return valid, msg - - def tile_descriptions(self) -> list: - """ - Returns a list of valid tile descriptions for the operations - - :returns: list of valid tile descriptions for the operations - :rtype: list - """ - tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] - if self._math_operation is not None: - tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation] - return tds - - def construct( - self, tile_description: TileDescription = None, - alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal: - """ - Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current - kernel specification of the ``Gemm`` object. - - :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass_cppgen.backend.TileDescription - :param alignment_A: alignment of operand A - :type alignment_A: int - :param alignment_B: alignment of operand B - :type alignment_B: int - :param alignment_C: alignment of operand C - :type alignment_C: int - - :return: operation that was constructed - :rtype: cutlass_cppgen.backend.GemmOperationUniversal - """ - alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) - alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) - alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A) - alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) - - tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A) - tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) - - if alignment_C is None: - alignment_C = max(self.possible_operations.alignments("C")) - if self._element_c != DataType.void: - alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C) - - if tile_description is None: - if self._tile_description is None: - op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] - tile_description = datatypes.td_from_profiler_op(op) - - # The selected op may have lower alignment than that determined above, so we must - # reset alignment here. - alignment_C = op.C.alignment - else: - tile_description = self._tile_description - else: - valid, err_str = self._valid_tile_description(tile_description) - if not valid: - raise Exception(f"Invalid tile description. {err_str}") - self._tile_description = tile_description - - tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) - self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) - - operation = GemmOperationUniversal( - arch=self.current_cc, - tile_description=tile_description, - A=tensor_A, B=tensor_B, C=tensor_C, - epilogue_functor=self.epilogue_functor, - swizzling_functor=self._swizzling_functor, - ) - - return operation - - def compile(self, tile_description: TileDescription = None, - alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, - print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal: - """ - Emits and compiles the kernel currently specified. If ``tile_description`` and any - of the ``alignment`` parameters are set, the kernel will be chosen using this - tile description and alignments. Otherwise, a default tile description and alignment - will be used. - - :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass_cppgen.backend.TileDescription - :param alignment_A: alignment of operand A - :type alignment_A: int - :param alignment_B: alignment of operand B - :type alignment_B: int - :param alignment_C: alignment of operand C - :type alignment_C: int - :param print_module: whether to print the emitted C++ code - :type print_module: bool - - :return: operation that was compiled - :rtype: cutlass_cppgen.backend.GemmOperationUniversal - """ - self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C) - - if print_module: - print(self.operation.rt_module.emit()) - - compiler.add_module([self.operation,]) - return self.operation - - def _verify_rank(self, tensor): - """ - Verifies that ``tensor`` has rank greater than 1 - - :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type tensor: numpy/cupy/torch array/tensor object - """ - if len(tensor.shape) < 2: - raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}") - - def _get_batch_count(self, A, B, C, D) -> int: - """ - Returns the batch count specified by the tensors A, B, C, and D and verifies that these - tensors match in batch size. Presence of a batch dimension is detected by one of the - tensors being rank 3. If a batch dimension is present, it must be present in one of - operands A, B, or C (but need not be in all), and must be present in D. - - :param A: tensor A - :type A: numpy/cupy/torch array/tensor object - :param B: tensor B - :type B: numpy/cupy/torch array/tensor object - :param C: tensor C - :type C: numpy/cupy/torch array/tensor object - :param D: tensor D - :type D: numpy/cupy/torch array/tensor object - - :return: tuple of batch count dimensions - :rtype: tuple - """ - A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1 - B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1 - - if 1 not in [A_batch, B_batch]: - if A_batch != B_batch: - raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}") - return max(A_batch, B_batch) - - def _get_batch_stride(self, tensor) -> int: - """ - Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0. - - :param tensor: tensor object to process - :type tensor: numpy/cupy/torch array/tensor object - - :return: stride between each matrix in the batch - :rtype: int - """ - if tensor is not None and len(tensor.shape) > 2: - return tensor.shape[-2] * tensor.shape[-1] - else: - return 0 - - def _get_problem_args(self, A, B, C, D) -> tuple: - """ - Returns the problem size and GEMM universal mode to use for the - given operands. - - :param A: tensor A - :type A: numpy/cupy/torch array/tensor object - :param B: tensor B - :type B: numpy/cupy/torch array/tensor object - :param C: tensor C - :type C: numpy/cupy/torch array/tensor object - :param D: tensor D - :type D: numpy/cupy/torch array/tensor object - - :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int) - :rtype: tuple - """ - M, K = A.shape[-2:] - N = B.shape[-1] - mode = GemmUniversalMode.Gemm - - batch_count = self._get_batch_count(A, B, C, D) - returned_batch_count = batch_count - - # If we are running a batched GEMM in which there is a nonzero batch stride - # only for A, then we can fold the batched dimension of A into the M dimension - # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A - # and C are row major. A similar operation can be performed if only B has a nonzero - # batch dimension - if batch_count > 1: - A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor - B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor - C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor - - # Consider a Tensor to be batched if its rank is > 2 and - # the product of the modes beyond rank 2 equals our pre-determined batch size. - batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count) - - if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row: - M *= batch_count - returned_batch_count = 1 - elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row: - N *= batch_count - returned_batch_count = 1 - else: - mode = GemmUniversalMode.Batched - - return GemmCoord(M, N, K), mode, returned_batch_count - - def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): - """ - Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception - is raised if it does not. - - :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type tensor: numpy/cupy/torch array/tensor object - :param ref_dtype: data type for the tensor that this object was initialized to - :param ref_layout: layout for the tensor that this object was initialized to - :param name: identifier of the tensor to verify. Used in raising exceptions - :type name: str - """ - dtype, layout = datatypes.get_datatype_and_layout(tensor) - if dtype != ref_type or layout != ref_layout: - try: - # Attempt to transpose the tensor to fit the desired layout - tensor = tensor.transpose(-1, -2) - except: - raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' - f'does not match the expected type and ' - f'layout of ({ref_type}, {ref_layout}) and transpose failed.') - - def run(self, A=None, B=None, C=None, D=None, - alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None, - stream: Optional[cuda.CUstream] = None) -> GemmArguments: - """ - Runs the kernel currently specified. If it has not already been, the kernel is emitted and - compiled. Tensors holding operands and outputs of the kernel are sourced either from the - ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` - parameters provided in this call, or from those - passed in on the construction of this object -- one of the two must be specified. - - By default, this call returns only once the kernel has completed. To launch the kernel - and immediately return, set ``sync=False``. In this case, it is the responsibility of the - caller to syncrhonize the results of the kernel before attempting to access outputs - by calling ``sync()`` on the arguments returned from this call. - - :param A: tensor representing data type and layout of operand A - :param B: tensor representing data type and layout of operand B - :param C: tensor representing data type and layout of operand C - :param D: tensor representing data type and layout of operand D - :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B - :param beta: scalar parameter beta from GEMM operation that scales operand C - :param sync: whether the call should wait for the kernel to complete before returning - :type sync: bool - :param print_module: whether to print the emitted C++ code - :type print_module: bool - :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) - :type stream: :class:`cuda.cuda.CUstream` - - :return: arguments passed in to the kernel - :rtype: cutlass_cppgen.backend.GemmArguments - """ - if not stream: - stream = cuda.CUstream(0) - super().run_setup() - A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") - B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") - C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") - D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") - alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") - beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") - - is_void_c = self._element_c == DataType.void - - self._verify_rank(A) - self._verify_rank(B) - if not is_void_c: - self._verify_rank(C) - self._verify_rank(D) - - alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") - alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") - - # Set C alignment based on D.shape so as to correctly get an alignment with void-C - # kernels, for which `C` is None. - alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C") - self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b, - alignment_C=alignment_c, print_module=print_module) - - problem_size, mode, batch_count = self._get_problem_args(A, B, C, D) - - if mode == GemmUniversalMode.Gemm or batch_count == 1: - kwargs = {'split_k_slices': 1} - else: - kwargs = { - 'batch': batch_count, - 'batch_strides': { - 'A': self._get_batch_stride(A), - 'B': self._get_batch_stride(B), - 'C': self._get_batch_stride(C), - 'D': self._get_batch_stride(D) - } - } - - kwargs['stream'] = stream - - if isinstance(self.epilogue_functor, EpilogueFunctorVisitor): - output_op = self.operation.epilogue_type(visitor_args) - else: - output_op = self.operation.epilogue_type(alpha, beta) - - arguments = GemmArguments( - operation=self.operation, problem_size=problem_size, - A=A, B=B, C=C, D=D, - output_op=output_op, - gemm_mode=mode, - **kwargs - ) - - self.operation.run(arguments) - - if sync: - arguments.sync() - - return arguments diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py deleted file mode 100644 index 59f90535c29a816541bc1a2155fea35afd1c94fd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py +++ /dev/null @@ -1,269 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" - Ease-of-use interface for constructing, compiling, and running GEMMs. - - The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run - grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters. - Under the hood, the interface will select sensible default parameters for the many template - parameters for CUTLASS grouped GEMMs. - - Note: optimal performance is not to be expected from this interface. To achieve optimal - performance, one should specify and tune each configuration parameter. - - The simplest example of using this interface is the following: - - .. highlight:: python - .. code-block:: python - - # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects - plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) - plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) -""" -from __future__ import annotations -from typing import Optional -from cutlass_library import DataTypeSize - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -from cutlass_cppgen.backend.gemm_operation import ( - GemmGroupedArguments, - GemmOperationGrouped, -) -from cutlass_cppgen.backend.library import ( - SchedulerMode, - TensorDescription, - TileDescription, -) -from cutlass_cppgen.op.gemm import Gemm -from cutlass_cppgen.shape import GemmCoord -from cutlass_cppgen.utils import check, datatypes - - -class GroupedGemm(Gemm): - """ - Constructs a ``GroupedGemm`` object. - - The data types and layouts of operands A, B, and C, along with the data type of output D - and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime -- - these are not to be changed after a ``GroupedGemm`` has been constructed. - - The constructor has optional parameters for flexibly setting these parameters. Please see the constructor - for ``Gemm`` for examples of these. - - :param cc: compute capability of device to generate kernels for - :type cc: int - :param A: tensor representing data type and layout of operands A - :param B: tensor representing data type and layout of operands B - :param C: tensor representing data type and layout of operands C - :param D: tensor representing data type and layout of operands D - :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B - :param beta: scalar parameter beta from GEMM operation that scales operand C - :param element_accumulator: data type to be used in accumulation of the product of operands A and B - :type element_accumulator: cutlass_cppgen.DataType - :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type - :type element: cutlass_cppgen.DataType - :param layout: generic layout type to be used for operands A, B, C, and D - :type layout: cutlass_cppgen.LayoutType - :param element_A: data type to be used for operand A - :type element_A: cutlass_cppgen.DataType - :param element_B: data type to be used for operand B - :type element_B: cutlass_cppgen.DataType - :param element_C: data type to be used for operand C - :type element_C: cutlass_cppgen.DataType - :param element_D: data type to be used for operand D - :type element_D: cutlass_cppgen.DataType - :type layout_A: layout of operand A - :param layout_A: cutlass_cppgen.LayoutType - :type layout_B: layout of operand B - :param layout_B: cutlass_cppgen.LayoutType - :type layout_C: layout of operand C - :param layout_C: cutlass_cppgen.LayoutType - :type layout_D: layout of operand D - :param layout_D: cutlass_cppgen.LayoutType - """ - - def __init__( - self, A=None, B=None, C=None, D=None, - alpha=1.0, beta=0.0, element_accumulator=None, - element=None, layout=None, - element_A=None, element_B=None, element_C=None, element_D=None, - layout_A=None, layout_B=None, layout_C=None, - cc: int = None, - ): - super().__init__( - A=A, B=B, C=C, D=D, - alpha=alpha, beta=beta, - element_accumulator=element_accumulator, - element=element, layout=layout, - element_A=element_A, element_B=element_B, - element_C=element_C, element_D=element_D, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - cc=cc - ) - - # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 - if self.current_cc in [90, 100, 101, 103]: - self._reset_options(80) - self._reset_operations(reset_epilogue=False) - - self.name = "grouped_gemm" - - @Gemm.swizzling_functor.setter - def swizzling_functor(self, swizzling_functor): - """ - Sets the swizzling functor to the type specified by `swizzling_functor` - """ - raise Exception('Grouped GEMM does not currently support different swizzling functors') - - def construct(self, tile_description: TileDescription = None, - alignment_A: int = None, - alignment_B: int = None, - alignment_C: int = None) -> GemmOperationGrouped: - """ - Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current - kernel specification of the ``Gemm`` object. - - :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass_cppgen.backend.TileDescription - :param alignment_A: alignment of operand A - :type alignment_A: int - :param alignment_B: alignment of operand B - :type alignment_B: int - :param alignment_C: alignment of operand C - :type alignment_C: int - - :return: operation that was constructed - :rtype: cutlass_cppgen.backend.GemmOperationGrouped - """ - alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A"))) - alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B"))) - alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C"))) - - self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) - - tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) - tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) - tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) - - if tile_description is None: - op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] - tile_description = datatypes.td_from_profiler_op(op) - else: - valid, err_str = self._valid_tile_description(tile_description) - if not valid: - raise Exception(f"Invalid tile description. {err_str}") - self.tile_description = tile_description - - operation = GemmOperationGrouped( - arch=self.current_cc, - tile_description=tile_description, - A=tensor_A, B=tensor_B, C=tensor_C, - epilogue_functor=self.epilogue_functor, - swizzling_functor=self._swizzling_functor, - precompute_mode=SchedulerMode.Device) - - return operation - - def run(self, A, B, C, D, - alpha=None, beta=None, sync: bool = True, - print_module: bool = False, - stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments: - """ - Runs the kernel currently specified. - - By default, this call returns only once the kernel has completed. To launch the kernel - and immediately return, set ``sync=False``. In this case, it is the responsibility of the - caller to syncrhonize the results of the kernel before attempting to access outputs - by calling ``sync()`` on the arguments returned from this call. - - :param A: list of tensors representing data type and layout of operand A - :type A: list - :param B: list of tensors representing data type and layout of operand B - :type B: list - :param C: list of tensors representing data type and layout of operand C - :type C: list - :param D: list of tensors representing data type and layout of operand D - :type D: list - :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B - :param beta: scalar parameter beta from GEMM operation that scales operand C - :param sync: whether the call should wait for the kernel to complete before returning - :type sync: bool - :param print_module: whether to print the emitted C++ code - :type print_module: bool - :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) - :type stream: :class:`cuda.cuda.CUstream` - - :return: arguments passed in to the kernel - :rtype: cutlass_cppgen.backend.GemmGroupedArguments - """ - if not stream: - stream = cuda.CUstream(0) - - super().run_setup() - - if len(A) != len(B) or len(A) != len(C) or len(A) != len(D): - raise Exception("Lengths of A, B, C, and D lists must be equal") - - problem_sizes = [] - As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4)) - for i in range(len(A)): - As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A") - Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B") - Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C") - Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D") - problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1])) - - alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") - beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") - - alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As)) - alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs)) - alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs)) - self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, - alignment_C=alignment_c, print_module=print_module) - - arguments = GemmGroupedArguments( - operation=self.operation, - problem_sizes=problem_sizes, - A=As, B=Bs, C=Cs, D=Ds, - output_op=self.operation.epilogue_type(alpha, beta), - stream=stream - ) - - self.operation.run(arguments) - - if sync: - arguments.sync() - - return arguments diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py deleted file mode 100644 index bebf07a7e5b83a1cf14cfecf19e90f730e305dce..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py +++ /dev/null @@ -1,431 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) -""" - -from bisect import bisect_left - -from cutlass_library import ( - DataType, - DataTypeSize, - MathOperation, - OperationKind, - SharedMemPerCC -) - -import cutlass_cppgen -from cutlass_cppgen import get_option_registry -from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor -from cutlass_cppgen.backend.evt.passes.util import cc_map -from cutlass_cppgen.backend.utils.device import device_cc -from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity -from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs -from cutlass_cppgen.swizzle import get_swizzling_functors -from cutlass_cppgen.utils import datatypes, check - - -class OperationBase: - """ - Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) - """ - - def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm): - """ - :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 - :type cc: int - :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 - :type kernel_cc: int - :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv) - :type operation_kind: cutlass_library.OperationKind - """ - self.operation_kind = operation_kind - self.cc = cc if cc is not None else device_cc() - self.specified_kernel_cc = kernel_cc is not None - self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) - self.tile_description = None - self._math_operation = None - - self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind) - - if self.options is None: - raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") - - # Default activation function: identity - self._activation = identity - - def _find_closest_cc(self, cc: int) -> int: - """ - Returns the closest CC in _generator_ccs less than or equal to `cc` - - :param cc: compute capability to query - :type cc: int - - :returns: closest CC in _generator_ccs less than or equal to `cc` - :rtype: int - """ - if cc in _generator_ccs: - return cc - - # Find closest CC lower than this CC - idx = bisect_left(_generator_ccs, cc) - if idx == 0: - raise Exception(f'No valid CC to fall back to for {cc}') - return _generator_ccs[idx-1] - - def activations(self) -> list: - """ - Returns possible activation functions that can be used - - :return: list of activation functions that can be used - :rtype: list - """ - return get_activations() - - def swizzling_functors(self) -> list: - """ - Returns possible swizzling functions that can be used - - :return: list of swizzling functions that can be used - :rtype: list - """ - return get_swizzling_functors() - - def _reset_options(self, cc: int): - """ - Resets the kernel options based on cc - - :param cc: compute capability to reset to - :type cc: int - """ - if cc != self.current_cc: - if cc not in _generator_ccs: - raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.') - self.current_cc = cc - self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind) - - def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): - """ - Verifies the following properties: - 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) - 2) If ``scalar`` is not ``None``, its datatype must match matches the current version - set by the plan (i.e., those in ``ref_dtype``) - - If either of these properties does not hold, an exception is raised. If these properties hold and - ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. - - :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type scalar: numpy/cupy/torch scalar - :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in - :type ref_scalar: numpy/cupy/torch scalar - :param ref_dtype: data type for the scalar that this object was initialized to - :param name: identifier of the scalar to verify. Used in raising exceptions - :type name: str - - :return: valid scalar to use - :rtype: numpy/cupy/torch scalar - """ - if scalar is None: - if ref_scalar is None: - raise Exception(f"Scalar {name} must be set.") - return ref_scalar - if hasattr(scalar, "dtype"): - dtype = datatypes.library_type(scalar.dtype) - if dtype != ref_dtype: - raise Exception( - f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." - ) - return scalar - - def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): - """ - Verifies the following properties: - If ref_dtype is not void: - 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) - 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions - set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) - If ref_dtype is void: - Neither ``tensor`` nor ``ref_tensor`` are set - - If either of these properties does not hold, an exception is raised. If these properties hold and - ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. - - :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type tensor: numpy/cupy/torch array/tensor object - :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in - :type ref_tensor: numpy/cupy/torch array/tensor object - :param ref_dtype: data type for the tensor that this object was initialized to - :param ref_layout: layout for the tensor that this object was initialized to - :param name: identifier of the tensor to verify. Used in raising exceptions - :type name: str - - :return: valid tensor object to use - :rtype: numpy/cupy/torch array/tensor object - """ - if ref_dtype == DataType.void: - if tensor is not None or ref_tensor is not None: - raise Exception("Operands with element DataType.void must not be provided a tensor") - return None - - if tensor is None: - if ref_tensor is None: - raise Exception(f"Tensor {name} must be set.") - return ref_tensor - - self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) - return tensor - - @property - def opclass(self) -> cutlass_cppgen.OpcodeClass: - """ - Returns the opcode class currently in use - - :return: opcode class currently in use - :rtype: cutlass_cppgen.OpcodeClass - """ - return self.op_class - - @opclass.setter - def opclass(self, oc: cutlass_cppgen.OpcodeClass): - if isinstance(oc, str): - oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc) - if oc in self.possible_op_classes: - self.op_class = oc - else: - raise Exception( - f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' - f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' - f'layout combination ({self._layout_a}, {self._layout_b}).') - - # Changing the op class also changes the possible operations available. Reset these. - self.possible_operations = self.options.operations( - self.op_class, self._element_a, self._element_b, - self._element_accumulator, self._layout_a, self._layout_b, self._math_operation) - - # Changing the op class changes the elements per access in the epilogue. Reset this. - if self.epilogue_functor is not None: - self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) - - @property - def math_operation(self) -> cutlass_cppgen.MathOperation: - """ - Returns the math operation currently in use - - :return: math operation currently in use - :rtype: cutlass_cppgen.MathOperation - """ - return self._math_operation - - @math_operation.setter - def math_operation(self, mo: cutlass_cppgen.MathOperation): - if isinstance(mo, str): - mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) - - if not self.specified_kernel_cc: - if self.current_cc in [90, 100, 101, 103]: - # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we - # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. - cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") - self._reset_options(80) - self._reset_operations(reset_epilogue=False) - elif self.current_cc in [90, 100, 101, 103]: - raise Exception("CUTLASS 3.0 kernels do not use different math operations. " - "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" - "parameter when constructing the plan.") - - self._math_operation = mo - self._reset_operations() - - def _elements_per_access(self): - if self.op_class == cutlass_cppgen.OpcodeClass.Simt: - return 1 - elif self._element_c != DataType.void: - return 128 // DataTypeSize[self._element_c] - else: - return 128 // max(self.possible_operations.alignments("C")) - - def _create_epilogue_functor_activation(self, activation): - """ - Returns the epilogue functor with given activation function - """ - if self.epilogue_functor is None: - elements_per_access = self._elements_per_access() - else: - elements_per_access = self.epilogue_functor.epilogue_vector_length - - if not self.specified_kernel_cc: - if self.current_cc in [90, 100, 101, 103] and activation != identity: - # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, - # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. - cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") - if self._element_c != self._element_d: - raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") - self._reset_options(80) - self._reset_operations(reset_epilogue=False) - elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None): - # SM80 fallback kernels are currently used. Since an identity activation is requested, - # we can switch back to using SM90 kernels. - self._reset_options(self.cc) - self._reset_operations(reset_epilogue=False) - else: - if self.current_cc in [90, 100, 101, 103] and activation != identity: - raise Exception("Epilogues with elementwise fusion are not currently supported " - "in the Python interface for 3.x kernels. To use 2.x kernels " - "with fused elementwise epilogues, do not set the `kernel_cc` " - "parameter when constructing the plan.") - - return get_activation_epilogue( - activation, - self._element_d, - elements_per_access, - self._element_accumulator, - self._element_accumulator, - ) - - def _reset_epilogue_functor_activation(self, activation): - """ - Set the epilogue functor based on the provided activation function - """ - self.epilogue_functor = self._create_epilogue_functor_activation(activation) - - def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor): - """ - Reset the alignment of the current epilogue functor based on alignment C - """ - if isinstance(epilogue_functor, EpilogueFunctorVisitor): - return epilogue_functor - - if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'): - # Identity epilogue does not have 'activation_functor' - activation = identity - else: - activation = epilogue_functor.activation_functor - - epilogue_functor = get_activation_epilogue( - activation, - self._element_d, - alignment, - self._element_accumulator, - self._element_accumulator, - ) - return epilogue_functor - - @property - def activation(self): - """ - Returns the type of the current activation function used - """ - if hasattr(self.epilogue_functor, "activation_functor"): - return self.epilogue_functor.activation_functor - else: - return identity - - @activation.setter - def activation(self, act): - """ - Sets the type of the activation function to use - Activation can come with a set of arguments - - :param act: type of activation function to use - :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01) - - """ - if isinstance(act, tuple): - if isinstance(act[0], str): - act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0]) - else: - act_fn = act[0] - self._reset_epilogue_functor_activation(act_fn) - self._activation_args = act[1] - self._activation = act[0] - else: - if isinstance(act, str): - act = getattr(cutlass_cppgen.backend.epilogue, act) - self._reset_epilogue_functor_activation(act) - self._activation = act - - @property - def epilogue_visitor(self): - """ - Return the epilogue functor - """ - return self.epilogue_functor - - @epilogue_visitor.setter - def epilogue_visitor(self, visitor): - """ - Create the epilogue visitor - """ - self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor) - - # The epilogue_functor may consume too much shared memory - # Reset the possible operations - if self.cc not in [90, 100, 101, 103]: - # The shared memory is only a concern for sm90+ epilogue - # In sm80, the epilogue and mainloop share the shared memory - return - - datatype_comb = self.possible_operations.datatype_comb - layout_comb = self.possible_operations.layout_comb - new_possible_operations = KernelsForDataType(datatype_comb, layout_comb) - for operation in self.possible_operations.all_operations: - td = datatypes.td_from_profiler_op(operation) - # Filter invalid epilogue schedules - if cc_map[self.cc] == 90 and td.epilogue_schedule not in [ - cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, - cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: - continue - epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td) - - # Verify the maximum number of mainloop stages - mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm) - smem_capacity_bytes = SharedMemPerCC[self.cc] << 10 - mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage - if mainloop_stages < 2: - # Mainloop stages must >= 2 - continue - - new_possible_operations.add(operation) - if len(new_possible_operations.all_operations) == 0: - raise RuntimeError( - "The epilogue consumes too much shared memory. " - "No valid tile description is found in the generator.") - self.possible_operations = new_possible_operations - - - def run_setup(self): - """ - Steps that must be taken before caling `plan.run()` - """ - # Initialize the memory pool if, if not already done - cutlass_cppgen.get_memory_pool() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py deleted file mode 100644 index a718f9bb4432f1f51457661abe27e24ea818aba4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py +++ /dev/null @@ -1,184 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for expressing shapes -""" - -from cutlass_library import ( - ConvMode, - ConvKind, - LayoutType -) -from cutlass_cppgen.backend.c_types import ( - Conv2DProblemSize_, - GemmCoord_, - GemmCoordBatched_ -) - - -class MatrixCoord: - def __init__(self, row, col): - self._row = row - self._col = col - - @property - def row(self): - return self._row - - @property - def column(self): - return self._col - - def leading_dimension(self, layout: LayoutType) -> int: - """ - Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord. - - :param layout: layout of matrix - :type layout: cutlass_library.LayoutType - - :returns: leading dimension - :rtype: int - """ - if layout == LayoutType.RowMajor: - return self._col - elif layout == LayoutType.ColumnMajor: - return self._row - else: - raise Exception(f'Unsupported layout for leading dimension calculation: {layout}') - - -class GemmCoord: - def __init__(self, m: int, n: int, k: int): - self._m = m - self._n = n - self._k = k - - @property - def m(self) -> int: - return self._m - - @property - def n(self) -> int: - return self._n - - @property - def k(self) -> int: - return self._k - - @property - def mk(self) -> MatrixCoord: - return MatrixCoord(self._m, self._k) - - @property - def mn(self) -> MatrixCoord: - return MatrixCoord(self._m, self._n) - - @property - def kn(self) -> MatrixCoord: - return MatrixCoord(self._k, self._n) - - @property - def ctype(self) -> GemmCoord_: - return GemmCoord_(self._m, self._n, self._k) - - def batched_ctype(self, batch_count: int) -> GemmCoordBatched_: - return GemmCoordBatched_(self._m, self._n, self._k, batch_count) - - -class Conv2DProblemSize: - def __init__( - self, n: int, h: int, w: int, c: int, - k: int, r: int, s: int, c_: int, - pad_h: int, pad_w: int, stride_h: int, stride_w: int, - dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation, - split_k_slices: int=1, groups: int=1): - - self.N = n - self.H = h - self.W = w - self.C = c - self.K = k - self.R = r - self.S = s - self.pad_h = pad_h - self.pad_w = pad_w - self.stride_h = stride_h - self.stride_w = stride_w - self.dilation_h = dilation_h - self.dilation_w = dilation_w - self.mode = int(mode) - self.split_k_slices = split_k_slices - self.groups = groups - self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1 - self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1 - - @property - def ctype(self) -> Conv2DProblemSize_: - return Conv2DProblemSize_(self) - - def implicit_gemm_size(self, kind: ConvKind): - if kind == ConvKind.Fprop: - return GemmCoord( - self.N * self.P * self.Q, - self.K, - self.R * self.S * self.C // self.groups - ) - elif kind == ConvKind.Dgrad: - return GemmCoord( - self.N * self.H * self.W, - self.C, - self.R * self.S * self.K - ) - elif kind == ConvKind.Wgrad: - return GemmCoord( - self.K, - self.R * self.S * self.C, - self.N * self.P * self.Q - ) - - @staticmethod - def from_sizes(input_size, weight_size): - K, R, S, _ = weight_size - pad_h = R // 2 - pad_w = S // 2 - stride_h = 1 - stride_w = 1 - dilation_h = 1 - dilation_w = 1 - return Conv2DProblemSize( - *input_size, - *weight_size, - pad_h, pad_w, - stride_h, stride_w, - dilation_h, dilation_w - ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py deleted file mode 100644 index ffd9483415ea36716bf4643d27b8d92f3e9878a5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py +++ /dev/null @@ -1,65 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Registry of swizzling functions -""" - -from cutlass_library import SwizzlingFunctor - - -IdentitySwizzle1 = SwizzlingFunctor.Identity1 -IdentitySwizzle2 = SwizzlingFunctor.Identity2 -IdentitySwizzle4 = SwizzlingFunctor.Identity4 -IdentitySwizzle8 = SwizzlingFunctor.Identity8 -HorizontalSwizzle = SwizzlingFunctor.Horizontal -ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK -StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1 -StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4 -StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal - - -_swizzling_functors = [ - IdentitySwizzle1, - IdentitySwizzle2, - IdentitySwizzle4, - IdentitySwizzle8, - HorizontalSwizzle, - ThreadblockSwizzleStreamK, - StridedDgradIdentitySwizzle1, - StridedDgradIdentitySwizzle4, - StridedDgradHorizontalSwizzle, -] - - -def get_swizzling_functors(): - return _swizzling_functors diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py deleted file mode 100644 index 75d8416a15070ddcf2c6270248ccd9deff8e2137..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_cppgen.utils.check import ( - alignment_or_default, - calculate_smem_usage, - calculate_smem_usage_per_stage, - valid_cluster_shape, - valid_schedule, - valid_stage_count, - update_alignment, -) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py deleted file mode 100644 index 108f268b4bc54ec0839afb5c1602ba63e5b98743..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py +++ /dev/null @@ -1,262 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utility functions for checking constraints on kernels and calculating kernel attributes -""" - -import ctypes - -from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC - -import cutlass_cppgen -from cutlass_cppgen.backend.library import TileDescription - - -def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int: - """ - Returns the amount of shared memory in bytes consumed in a single stage of a kernel. - - :param td: tile description to compute shared memory of - :type td: TileDescription - :param operation_kind: identifier for the type of operation being performed - :type operation_kind: cutlass_library.OperationKind - - :return: number of bytes of shared memory consumed by a single stage - :rtype: int - """ - m, n, k = td.blackwell_threadblock_shape - if td.is_2sm: - m //= 2 - - if operation_kind == OperationKind.Gemm: - stage_barrier_bytes = 32 - return ( - (DataTypeSize[td.math_instruction.element_a] * m * k // 8) - + (DataTypeSize[td.math_instruction.element_b] * k * n // 8) - + stage_barrier_bytes - ) - else: - raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}") - - -def calculate_smem_usage(operation) -> int: - """ - Returns the amount of shared memory in bytes consumed by a kernel. - - :return: number of bytes of shared memory consumed by the operation - :return: int - """ - _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind) - return _per_stage * operation.tile_description.stages - - -def valid_stage_count( - cc: int, - kernel_cc: int, - td: TileDescription, - element_C: cutlass_cppgen.DataType = None, - element_D: cutlass_cppgen.DataType = None, - verbose: bool = True) -> tuple: - """ - Checks whether a device with `cc` supports the number of stages within `tile_description`, both - based on raw limits on the number of stages and based on shared memory capacity - - :param cc: compute capability of device in question - :type cc: int - :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS) - :type kernel_cc: int - :param td: tile description to check - :type td: TileDescription - :param element_C: data type of operand C - :type element_C: cutlass_cppgen.DataType - :param element_D: data type of operand D - :type element_D: cutlass_cppgen.DataType - :param verbose: whether to log warnings - :type verbose: bool - - :return: tuple with the first element indicating whether the provided tile description is - valid for the provided device and the second element being an error message - :rtype: tuple - """ - if kernel_cc in [90, 100, 101, 103]: - if (td.stages is None or td.stages == 0): - # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically - # determines the stage count to use. Thus, all settings are valid in these scenarios. - return (True, "") - elif verbose: - cutlass_cppgen.logger.warning( - "Setting an explicit stage count for SM90 kernels currently may " - "result in compilation errors if the combination of tile shape, " - "stage count, and shared memory requirement of the epilogue exceeds " - "the available shared memory per SM.") - - if td.stages <= 0: - return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") - - if cc < 80 and td.stages != 2: - return (False, f"Tile description has stage count of {td.stages}, " - f"but only 2 stages are supported on SM{cc}.") - - # The calculation below does not consider shared memory used by the epilogue and, thus, - # only catches cases in which the mainloop exceeds the device's shared memory capacity. - # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the - # mainloop and epilogue is shared. - smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm) - smem_usage_mainloop = (smem_per_stage * td.stages) - smem_arch = SharedMemPerCC[cc] << 10 - if smem_usage_mainloop > smem_arch: - return ( False, - "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n" - f"Details:\n" - f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and " - f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n" - f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.") - - return (True, "") - - -def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: - """ - Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`. - - :param cc: compute capability of device in question - :type cc: int - :param cluster_shape: dimensions of thread block cluster shape to check - :type cluster_shape: list - - :return: tuple with the first element indicating whether the provided cluster shape is - valid for the provided device and the second element being an error message - :rtype: tuple - """ - - if cc < 90 or cc in [120, 121]: - if cluster_shape != [1, 1, 1]: - return (False, - f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of " - f"{cluster_shape} for SM{cc}.") - else: - return (True, "") - - if len(cluster_shape) != 3: - return (False, - f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}") - - if cluster_shape[2] != 1: - return (False, - "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " - f"Received cluster shape of {cluster_shape}.") - - return (True, "") - - -def valid_schedule( - cc: int, - kernel_schedule: cutlass_cppgen.KernelScheduleType, - epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, - tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple: - """ - Checks that the kernel and epilogue schedules passed in are a valid combination for - a device of compute capability ``cc``. - - :param cc: compute capability of device in question - :type cc: int - :param kernel_schedule: kernel schedule type - :type kernel_schedule: cutlass_cppgen.KernelScheduleType - :param epilogue_schedule: epilogue schedule type - :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType - :param tile_scheduler: tile scheduler type - :type tile_scheduler: cutlass_cppgen.TileSchedulerType - - :return: tuple with the first element indicating whether the provided schedules are - valid for the provided device and the second element being an error message - :rtype: tuple - """ - kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) - epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) - tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) - if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default): - return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)") - - if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)): - return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") - - if not tile_scheduler_default: - cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, - cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] - if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): - return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") - return (True, "") - - -def alignment_or_default(alignment_provided: int, default_alignment: int) -> int: - """ - Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks - that `alignment_provided` does not exceed `default_alignment`. - - :param alignment_provided: alignment preference specified. Can be None. - :type alignment_provided: int - :param default_alignment: alignment to use if `alignment_provided` is None - :type default_alignment: int - - :return: alignment to use - :rtype: int - """ - if alignment_provided is not None: - if alignment_provided > default_alignment: - raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") - return alignment_provided - - return default_alignment - - -def update_alignment(alignment_provided:int, default_alignment: int) -> int: - """ - Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks - that `alignment_provided` does not exceed `default_alignment`. - - :param alignment_provided: alignment preference specified. Can be None. - :type alignment_provided: int - :param default_alignment: alignment to use if `alignment_provided` is None - :type default_alignment: int - - :return: alignment to use - :rtype: int - """ - if alignment_provided is not None: - if alignment_provided > default_alignment: - if alignment_provided % default_alignment == 0: - return default_alignment - raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") - return alignment_provided - - return default_alignment diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py deleted file mode 100644 index c03a834dc47871bebe618752e4775a0a7434ff78..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py +++ /dev/null @@ -1,362 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utility functions for converting between frontend datatypes and CUTLASS datatypes -""" - -import cutlass_cppgen -from cutlass_library import ( - DataTypeSize, - MathOperation, - MathInstruction -) -from cutlass_cppgen.backend.library import ( - TileDescription, -) - -bfloat16_available = None -cupy_available = None -numpy_available = None -torch_available = None -_library_to_cupy_dict = None -_library_to_numpy_dict = None -_library_to_torch_dict = None -_torch_to_library_dict = None - - -def is_numpy_available(): - global numpy_available, _library_to_numpy_dict - if numpy_available is None: - try: - import numpy as np - - numpy_available = True - _library_to_numpy_dict = { - cutlass_cppgen.DataType.f16: np.float16, - cutlass_cppgen.DataType.f32: np.float32, - cutlass_cppgen.DataType.f64: np.float64, - cutlass_cppgen.DataType.s8: np.int8, - cutlass_cppgen.DataType.s32: np.int32, - } - except ImportError: - numpy_available = False - _library_to_numpy_dict = {} - return numpy_available - - -def is_numpy_tensor(inp) -> bool: - if is_numpy_available(): - import numpy as np - return isinstance(inp, np.ndarray) - return False - - -def numpy_library_type(inp) -> cutlass_cppgen.DataType: - if is_numpy_available(): - import numpy as np - if inp == np.float16: - return cutlass_cppgen.DataType.f16 - elif inp == np.float32: - return cutlass_cppgen.DataType.f32 - elif inp == np.float64: - return cutlass_cppgen.DataType.f64 - elif inp == np.int8: - return cutlass_cppgen.DataType.s8 - elif inp == np.int32: - return cutlass_cppgen.DataType.s32 - return None - - -def numpy_type(inp): - return _library_to_numpy_dict.get(inp, None) - - -def is_cupy_available(): - global cupy_available - if cupy_available is None: - try: - import cupy as cp - - cupy_available = True - _library_to_cupy_dict = { - cutlass_cppgen.DataType.f16: cp.float16, - cutlass_cppgen.DataType.f32: cp.float32, - cutlass_cppgen.DataType.f64: cp.float64, - cutlass_cppgen.DataType.s8: cp.int8, - cutlass_cppgen.DataType.s32: cp.int32, - } - except ImportError: - cupy_available = False - _library_to_cupy_dict = {} - return cupy_available - - -def is_cupy_tensor(inp) -> bool: - if is_cupy_available(): - import cupy as cp - return isinstance(inp, cp.ndarray) - return False - - -def cupy_library_type(inp) -> cutlass_cppgen.DataType: - if is_cupy_available(): - import cupy as cp - if inp == cp.float16: - return cutlass_cppgen.DataType.f16 - elif inp == cp.float32: - return cutlass_cppgen.DataType.f32 - elif inp == cp.float64: - return cutlass_cppgen.DataType.f64 - return None - - -def cupy_type(inp): - return _library_to_cupy_dict.get(inp, None) - - -def is_torch_available(): - global torch_available, _library_to_torch_dict, _torch_to_library_dict - if torch_available is None: - try: - import torch - - torch_available = True - _torch_to_library_dict = { - torch.half: cutlass_cppgen.DataType.f16, - torch.float16: cutlass_cppgen.DataType.f16, - torch.bfloat16: cutlass_cppgen.DataType.bf16, - torch.float: cutlass_cppgen.DataType.f32, - torch.float32: cutlass_cppgen.DataType.f32, - torch.double: cutlass_cppgen.DataType.f64, - torch.float64: cutlass_cppgen.DataType.f64, - torch.int8: cutlass_cppgen.DataType.s8, - torch.int32: cutlass_cppgen.DataType.s32, - torch.uint8: cutlass_cppgen.DataType.u8, - } - - _library_to_torch_dict = { - cutlass_cppgen.DataType.f16: torch.half, - cutlass_cppgen.DataType.f16: torch.float16, - cutlass_cppgen.DataType.bf16: torch.bfloat16, - cutlass_cppgen.DataType.f32: torch.float, - cutlass_cppgen.DataType.f32: torch.float32, - cutlass_cppgen.DataType.f64: torch.double, - cutlass_cppgen.DataType.f64: torch.float64, - cutlass_cppgen.DataType.s8: torch.int8, - cutlass_cppgen.DataType.s32: torch.int32, - cutlass_cppgen.DataType.u8: torch.uint8, - } - - def possibly_add_type(torch_type_name, cutlass_type): - # Only try adding the type if the version of torch being used supports it - if hasattr(torch, torch_type_name): - torch_type = getattr(torch, torch_type_name) - _torch_to_library_dict[torch_type] = cutlass_type - _library_to_torch_dict[cutlass_type] = torch_type - - possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3) - possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2) - - except ImportError: - torch_available = False - _torch_to_library_dict = {} - _library_to_torch_dict = {} - return torch_available - - -def is_torch_tensor(inp) -> bool: - if is_torch_available(): - import torch - return isinstance(inp, torch.Tensor) - return False - - -def torch_library_type(inp) -> cutlass_cppgen.DataType: - return _torch_to_library_dict.get(inp, None) - - -def torch_type(inp): - return _library_to_torch_dict.get(inp, None) - - -def is_bfloat16_available(): - global bfloat16_available - - if bfloat16_available is None: - try: - import bfloat16 - - bfloat16_available = True - except ImportError: - bfloat16_available = False - return bfloat16_available - - -def bfloat16_library_type(inp) -> cutlass_cppgen.DataType: - if is_bfloat16_available(): - import bfloat16 - if inp == bfloat16.bfloat16: - return cutlass_cppgen.DataType.bf16 - - -def bfloat16_type(inp): - if is_bfloat16_available(): - import bfloat16 - if inp == cutlass_cppgen.DataType.bf16: - return bfloat16.bfloat16 - - -def library_type(inp): - if inp in DataTypeSize: - return inp - - for cvt_fn in [ - bfloat16_library_type, - cupy_library_type, - numpy_library_type, - torch_library_type, - ]: - out = cvt_fn(inp) - if out is not None: - return out - - raise Exception(f"No available conversion from type {inp} to a library type.") - - -def _tensor_from_numpy(np_tensor): - dtype = library_type(np_tensor.dtype) - if np_tensor.flags.c_contiguous: - layout = cutlass_cppgen.LayoutType.RowMajor - elif np_tensor.flags.f_contiguous: - layout = cutlass_cppgen.LayoutType.ColumnMajor - return (dtype, layout) - - -def _tensor_from_torch(pt_tensor): - dtype = library_type(pt_tensor.dtype) - return (dtype, cutlass_cppgen.LayoutType.RowMajor) - - -def get_datatype_and_layout(tensor): - if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): - return _tensor_from_numpy(tensor) - elif is_torch_tensor(tensor): - return _tensor_from_torch(tensor) - elif isinstance(tensor, float) or isinstance(tensor, int): - return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor) - else: - raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") - - -def get_tensor_shape(tensor, op="GEMM"): - if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): - return tensor.shape - elif is_torch_tensor(tensor): - size = tensor.size() - if op == "CONV": - # PyTorch Tensors have shape NCHW - return (size[0], size[2], size[3], size[1]) - else: - return tuple(tensor.size()) - elif isinstance(tensor, float) or isinstance(tensor, int): - return (1,) - else: - raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") - - -_math_operation_value_map = {x.value: x for x in MathOperation} - - -def backend_math_operation(math_op: MathOperation): - if math_op.value not in _math_operation_value_map.keys(): - raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.") - return _math_operation_value_map[math_op.value] - - -def construct_backend_td(td: cutlass_cppgen.TileDescription, - kernel_schedule: cutlass_cppgen.KernelScheduleType, - epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, - tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription: - mi = td.math_instruction - backend_mi = MathInstruction( - mi.instruction_shape, - mi.element_a, - mi.element_b, - mi.element_accumulator, - mi.opcode_class, - backend_math_operation(mi.math_operation) - ) - cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1] - return TileDescription(td.threadblock_shape, td.stages, td.warp_count, - backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler) - - -def td_from_profiler_op(op) -> TileDescription: - """ - Converts the profiler's TileDescription in ``op`` into the backend TileDescription - - :param op: profiler Operation - - :returns: backend TileDescription - :rtype: cutlass_cppgen.backend.TileDescription - """ - kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None - eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None - tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None - return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule) - - -def td_from_profiler_td(td: TileDescription) -> TileDescription: - """ - Converts the profiler's TileDescription into the backend TileDescription - - :param td: profiler TileDescription - :type td: cutlass_cppgen.TileDescription - - :returns: backend TileDescription - :rtype: cutlass_cppgen.backend.TileDescription - """ - return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None) - - -def to_camel_case(snake_str): - return "".join(x.capitalize() for x in snake_str.lower().split("_")) - - -def getattr_enum(obj, attr_name): - # The attr_name is under the snake_case - camel_attr = to_camel_case(attr_name) - if hasattr(obj, camel_attr): - return getattr(obj, camel_attr) - else: - raise Exception(f"Invalid option: {attr_name}") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py deleted file mode 100644 index 16f6a185040f4c2f6167c6191c9bee766a92b1b9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py +++ /dev/null @@ -1,41 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -import importlib -from typing import Any - -def lazy_import(mod_name: str) -> Any: - class Lazy: - def __getattr__(self, name:str) -> Any: - module = importlib.import_module(mod_name) - return getattr(module, name) - - return Lazy() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py deleted file mode 100644 index f53b1567978d17f2eaec0208d896aafb296f033f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py +++ /dev/null @@ -1,196 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Profiler based on the cuda events -""" - -import re -import subprocess - -from cutlass_cppgen.utils.lazy_import import lazy_import -cuda = lazy_import("cuda.cuda") -cudart = lazy_import("cuda.cudart") -import numpy as np - -from cutlass_cppgen import CUTLASS_PATH -from cutlass_cppgen.backend.library import DataTypeSize -from cutlass_cppgen.op.op import OperationBase -from cutlass_cppgen.shape import GemmCoord -from cutlass_cppgen.utils.datatypes import is_numpy_tensor - - -class GpuTimer: - def __init__(self) -> None: - self.events = [ - cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], - cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], - ] - - def start(self, stream=None): - if not stream: - stream = cuda.CUstream(0) - - (err,) = cuda.cuEventRecord(self.events[0], stream) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - - def stop(self, stream=None): - if not stream: - stream = cuda.CUstream(0) - - (err,) = cuda.cuEventRecord(self.events[1], stream) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - pass - - def stop_and_wait(self, stream=None): - if not stream: - stream = cuda.CUstream(0) - - self.stop(stream) - if stream: - (err,) = cuda.cuStreamSynchronize(stream) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - else: - (err,) = cudart.cudaDeviceSynchronize() - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - - def duration(self, iterations=1): - err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1]) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"CUDA Error {str(err)}") - return duration / float(iterations) - - -class CUDAEventProfiler: - def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None: - self.arguments = op.run(*args, **kwargs) - self.operation = op.operation - self.warmup_iterations = warmup_iterations - self.iterations = iterations - self.timer = GpuTimer() - - # - # Cutlass Python Interface Profiler - # - - def __call__(self): - for _ in range(self.warmup_iterations): - self.operation.run(self.arguments) - - self.timer.start() - for _ in range(self.iterations): - self.operation.run(self.arguments) - - self.timer.stop_and_wait() - runtime = self.timer.duration(self.iterations) - return runtime - - # - # CUTLASS Profiler - # - - def run_cutlass_profiler(self): - alpha = 1.0 - beta = 1.0 - - profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler" - kernel_name = self.operation.procedural_name() - verification_providers = "device" - provider = "cutlass" - problem_size = self.arguments.problem_size - - if "cutlass3x" in kernel_name: - # cutlass3x generator only have column-major output - layout_name = self.operation.layout_name_3x() - if layout_name[-1] == "t": - new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"]) - problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) - kernel_name = kernel_name.replace(layout_name, new_layout_name) - - batch_count = self.arguments.batch_count - - cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \ - f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \ - f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\ - f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}" - - result = subprocess.getoutput(cmd) - - m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) - runtime = float(m.group("runtime")) - - m = re.search(r"Bytes:\s+(?P\d+)", result) - bytes = int(m.group("bytes")) - - m = re.search(r"FLOPs:\s+(?P\d+)", result) - flops = int(m.group("flops")) - - # check if the problem size matches - assert bytes == self.bytes(problem_size, batch_count, beta) - assert flops == self.flops(problem_size, batch_count, beta) - - return runtime - - def bytes(self, problem_size, batch_count=1, beta=0.0): - m = problem_size.m() - n = problem_size.n() - k = problem_size.k() - - bytes = ( - (DataTypeSize[self.operation.A.element] * m // 8) * k - + (DataTypeSize[self.operation.B.element] * n // 8) * k - + (DataTypeSize[self.operation.C.element] * m // 8) * n - ) - - if beta != 0: - bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n - - bytes *= batch_count - - return bytes - - def flops(self, problem_size, batch_count=1, beta=0.0): - m = problem_size.m() - n = problem_size.n() - k = problem_size.k() - - flops_ = (m * n * k) * 2 * batch_count - - if beta != 0: - flops_ += m * n * batch_count * 2 - - return flops_ - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py deleted file mode 100644 index 534eef47d810eb9f17a9ba6dbbe2e0dff935eb3f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import os -import sys - -from . import conv2d_operation -from . import conv3d_operation -from . import emit_kernel_listing -from . import gemm_operation - -if '-m' not in sys.argv: - # Do not import generator when running python -m cutlass_library.generator to - # avoid double-import warnings - from . import generator - -from . import library -from . import manifest -from . import rank_2k_operation -from . import rank_k_operation -from . import symm_operation -from . import trmm_operation -# Make enum types from library.py accessible via cutlass_library.* -from .library import * - -# Set up `source` to point to the path containing the CUTLASS source. -# Check first if the path contains a `source` subdirectory -- this will -# be the case when the package has been installed via pip. Otherwise, -# default to the root of CUTLASS. -install_source_path = os.path.join(__path__[0], 'source') -if os.path.isdir(install_source_path): - source_path = install_source_path -else: - source_path = os.path.join(__path__[0], '../..') diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py deleted file mode 100644 index b674463a2c5795be8610883c4dc98a1e7123a01b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py +++ /dev/null @@ -1,621 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting Conv2d kernels -""" - -import enum -import logging -import os.path -import shutil -from string import Template - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * - from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes -except ImportError: - from library import * - from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes - -_LOGGER = logging.getLogger(__name__) - -################################################################################################### - -# -class Conv2dOperation: - # - def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ - stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \ - group_mode = GroupMode.NoneGroup): - - self.operation_kind = OperationKind.Conv2d - self.arch = arch - self.tile_description = tile_description - self.conv_kind = conv_kind - self.A = A - self.B = B - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.iterator_algorithm = iterator_algorithm - self.stride_support = stride_support - self.swizzling_functor = swizzling_functor - self.group_mode = group_mode - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian - ] - return self.tile_description.math_instruction.math_operation in complex_operators - - # - def is_mixed_input(self): - return self.A.element != self.B.element - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - intermediate_type = '' - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.accumulator_type(): - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - else: - inst_shape = '' - - return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ - inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - return "%s" % (ShortLayoutTypeNames[self.A.layout]) - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - threadblock = self.tile_description.procedural_name() - - # grouped conv - if self.group_mode != GroupMode.NoneGroup: - group_conv_name = f"{GroupModeNames[self.group_mode]}_" - else: - group_conv_name = "" - - if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}" - else: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}" - - return SubstituteTemplate( - configuration_name, - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'alignment': "%d" % self.A.alignment, - 'group_conv_name': group_conv_name - } - ) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.configuration_name() - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -class EmitConv2dInstance: - def __init__(self): - # Emitter for CUTLASS 3 convolution operations - self.conv3x_emitter = EmitConv3xInstance() - self.template = """ - // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" - using ${operation_name}_base = - typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< - ${element_a}, - ${layout_a}, - ${element_b}, - ${layout_b}, - ${element_c}, - ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, - ${stages}, - ${math_operator}, - ${iterator_algorithm}, - ${stride_support}, - ${align_a}, - ${align_b} - >::Kernel; -""" - self.template_group_conv = """ - // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" - using ${operation_name}_base = - typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}< - ${element_a}, - ${layout_a}, - ${element_b}, - ${layout_b}, - ${element_c}, - ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, - ${stages}, - ${math_operator}, - ${group_mode}, - ${iterator_algorithm}, - ${stride_support}, - ${align_a}, - ${align_b} - >::Kernel; -""" - self.template_depthwise_direct_conv = """ - // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" - using ${operation_name}_base = - typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}< - ${element_a}, - ${layout_a}, - ${element_b}, - ${layout_b}, - ${element_c}, - ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>, - cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue}, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling - >, - - cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< - 1, - ${threadblock_output_shape_n}, - ${threadblock_output_shape_p}, - ${threadblock_output_shape_q}>, - ${stages}, - ${math_operator}, - ${iterator_algorithm}, - ${stride_support}, - cutlass::MatrixShape<${stride_r}, ${stride_s}>, - cutlass::MatrixShape<${dilation_r}, ${dilation_s}> - >::Kernel; -""" - - def arch_number_to_type(self, arch: int): - return f"cutlass::arch::Sm{arch}" - - def emit(self, operation): - _LOGGER.debug("*** EmitConv2dInstance::emit") - _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) - - if hasattr(operation, 'is_3x') and operation.is_3x: - _LOGGER.debug("*** CUTLASS 3 operation") - return self.conv3x_emitter.emit(operation) - - _LOGGER.debug("*** CUTLASS 2 operation") - - warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'conv_kind': ConvKindTag[operation.conv_kind], - 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], - 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), - 'stride_support': StrideSupportTag[operation.stride_support], - 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \ - MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - } - - if operation.group_mode == GroupMode.NoneGroup: - _LOGGER.debug("*** group_mode=NoneGroup") - return SubstituteTemplate(self.template, values) - - elif operation.group_mode == GroupMode.Depthwise: - _LOGGER.debug("*** group_mode=Depthwise") - values['group_mode'] = GroupModeTag[operation.group_mode] - # Setup other template params - values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0]) - values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1]) - values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2]) - - values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3]) - - values['filter_shape_r'] = str(operation.tile_description.filter_shape[0]) - values['filter_shape_s'] = str(operation.tile_description.filter_shape[1]) - - values['stride_r'] = str(operation.tile_description.stride[0]) - values['stride_s'] = str(operation.tile_description.stride[1]) - - values['dilation_r'] = str(operation.tile_description.dilation[0]) - values['dilation_s'] = str(operation.tile_description.dilation[1]) - - return SubstituteTemplate(self.template_depthwise_direct_conv, values) - - else: - _LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode]) - values['group_mode'] = GroupModeTag[operation.group_mode] - return SubstituteTemplate(self.template_group_conv, values) - -################################################################################################### -# -# Generator functions for all layouts -# -################################################################################################### - -# -def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128): - _LOGGER.debug("*** GenerateConv2dTensorOp") - - for tile in tile_descriptions: - for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: - - if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): - - # - output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ - if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ - else [tile.math_instruction.element_accumulator,] - - for output_type in output_types: - A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) - B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) - C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type]))) - - manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) - -class EmitConv2dIncludes: - '''Emit includes that are specific to the operation.''' - - def __init__(self): - self.includes = ['conv2d_operation.h'] - self.emitter_3x = EmitConv3xIncludes() - - def operation_is_3x(self, operation) -> bool: - """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" - return hasattr(operation, 'is_3x') and operation.is_3x - - def emit(self, operation) -> str: - if self.operation_is_3x(operation): - return self.emitter_3x.emit(operation) - - return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ - "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitConv2dConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) - - self.instance_emitter = EmitConv2dInstance() - self.includes_emitter = EmitConv2dIncludes() - - self.header_template = """ -/* - Generated by conv2d_operation.py - Do not edit. -*/ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "library_internal.h" -""" - - self.instance_template = """ -${stub_begin} -${operation_instance} -// Derived class -struct ${operation_name} : - public ${operation_name}_base { }; -${stub_end} -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.configuration_header = """ - -namespace cutlass { -namespace library { - -// Initialize all instances -void initialize_${configuration_name}(Manifest &manifest) { -""" - - self.configuration_instance = """${stub_begin} - using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< - ${operation_name}>; - - manifest.append(new cutlass::library::${operation_wrapper}< - Operation_${operation_name} - >( - "${operation_name}" - )); -${stub_end} -""" - - self.configuration_epilogue = "}\n" - - self.epilogue_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def operation_is_3x(self, operation): - """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" - return hasattr(operation, 'is_3x') and operation.is_3x - - def __enter__(self): - """ - Open the configuration_file, and write the "header" C++ code to it. - - The "header" consists of a comment (that this is generated code, - so it should not be edited), and includes that are common - to all kinds of kernels. - """ - _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__') - _LOGGER.debug('*** configuration_path (file to write): ' + - str(self.configuration_path)) - _LOGGER.debug('*** configuration_name: ' + self.configuration_name) - self.configuration_file = open(self.configuration_path, "w") - - self.configuration_file.write(SubstituteTemplate(self.header_template, { - 'configuration_name': self.configuration_name - })) - self.operations = [] - return self - - def emit(self, operation): - """ - Write three pieces of C++ code to the configuration_file - (that was opened by the __enter__ method above): - - 1. the header includes that are specific to the operation - (CUTLASS 2 vs. CUTLASS 3); - - 2. the "operation instance" (a "using" declaration ending in "_base"); and - - 3. the "operation name" (declaration and definition of a derived class - of the above operation instance). - - The "using" declaration turns a C++ class name, possibly namespace-qualified, - possibly also with angle brackets, into a C-style, easily demangled identifier. - """ - _LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit') - _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) - self.operations.append(operation) - - self.configuration_file.write(self.includes_emitter.emit(operation)) - - stub_begin = '' - stub_end = '' - # It can be useful to stub (comment) out instantiations for testing. - # In this case, one need only set is_stub to True. - is_stub = False - if is_stub: - stub_begin = "// STUB for now\n#if 0" - stub_end = '#endif // 0' - - self.configuration_file.write(Template(self.instance_template).substitute({ - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'operation_instance': self.instance_emitter.emit(operation), - 'stub_begin': stub_begin, - 'stub_end': stub_end - })) - - def __exit__(self, exception_type, exception_value, traceback): - """ - Write the rest of the C++ code to the configuration_file, and close the file. - - The "rest of the C++ code" has the following components. - - 1. Configuration header: Open the namespace(s), and open the definition - of the "initialize_${configuration_name}" registration function - that registers the operation with the Manifest. - ("Registration" helps turn C++ compile-time polymorphism - (via template parameters) into a run-time choice of parameters.) - - 2. Configuration instance: In the body of the registration function, - make a "using" declaration Operation_${operation_name} for the - operation type (which uses operation_name as its template argument). - Then, tell the manifest about the operation via a "manifest.append" call. - The argument of the call is a new instance of - "SomethingOperation" - (replace Something with a specific name). - - 3. Configuration epilogue: Close the definition of the registration function. - - 4. Epilogue template: Close the namespace(s). - """ - - _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__') - _LOGGER.debug('*** configuration_path (file to write): ' + - str(self.configuration_path)) - _LOGGER.debug('*** configuration_name: ' + self.configuration_name) - - self.configuration_file.write(SubstituteTemplate(self.configuration_header, { - 'configuration_name': self.configuration_name - })) - - for operation in self.operations: - stub_begin = '' - stub_end = '' - # It can be useful to stub (comment) out instantiations for testing. - # In this case, one need only set is_stub to True. - is_stub = False - if is_stub: - stub_begin = "// STUB for now\n#if 0" - stub_end = "#endif // 0" - - if operation.group_mode == GroupMode.Depthwise: - kernel_name = 'DirectConvolution' - operation_wrapper = 'DirectConv2dOperation' - else: - kernel_name = 'ImplicitGemmConvolution' - operation_wrapper = 'Conv2dOperation' - if self.operation_is_3x(operation): - kernel_name = 'ConvUniversalAdapter' - operation_wrapper = 'ConvOperation3x' - - self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'kernel_name': kernel_name, - 'operation_wrapper': operation_wrapper, - 'stub_begin': stub_begin, - 'stub_end': stub_end - })) - - self.configuration_file.write(self.configuration_epilogue) - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - - -################################################################################################### -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py deleted file mode 100644 index b96b6db74224e52bd90b6e184a62624475385352..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py +++ /dev/null @@ -1,482 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting Conv3d kernels -""" - -import enum -import logging -import os.path -import shutil -from string import Template - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * - from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes -except ImportError: - from library import * - from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes - -_LOGGER = logging.getLogger(__name__) - -################################################################################################### - -# -class Conv3dOperation: - # - def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ - stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - - self.operation_kind = OperationKind.Conv3d - self.arch = arch - self.tile_description = tile_description - self.conv_kind = conv_kind - self.A = A - self.B = B - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.iterator_algorithm = iterator_algorithm - self.stride_support = stride_support - self.swizzling_functor = swizzling_functor - - # - def is_mixed_input(self): - return self.A.element != self.B.element - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - intermediate_type = '' - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - else: - inst_shape = '' - - return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \ - inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - threadblock = "%dx%d_%dx%d" % ( - self.tile_description.threadblock_shape[0], - self.tile_description.threadblock_shape[1], - self.tile_description.threadblock_shape[2], - self.tile_description.stages - ) - - if self.stride_support == StrideSupport.Unity: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride" - else: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}" - - return SubstituteTemplate( - configuration_name, - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - } - ) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.configuration_name() - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -class EmitConv3dInstance: - def __init__(self): - # Emitter for CUTLASS 3 convolution operations - self.conv3x_emitter = EmitConv3xInstance() - self.template = """ - // Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" - using ${operation_name}_base = - typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}< - ${element_a}, - cutlass::layout::TensorNDHWC, - ${element_b}, - cutlass::layout::TensorNDHWC, - ${element_c}, - cutlass::layout::TensorNDHWC, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, - ${stages}, - cutlass::arch::OpMultiplyAdd, - ${iterator_algorithm}, - ${stride_support} - >::Kernel; -""" - - def emit(self, operation): - _LOGGER.debug("*** EmitConv3dInstance::emit") - _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) - - if hasattr(operation, 'is_3x') and operation.is_3x: - _LOGGER.debug("*** CUTLASS 3 operation") - return self.conv3x_emitter.emit(operation) - - _LOGGER.debug("*** CUTLASS 2 operation") - - warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'conv_kind': ConvKindTag[operation.conv_kind], - 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], - 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), - 'stride_support': StrideSupportTag[operation.stride_support] - } - - return SubstituteTemplate(self.template, values) - -################################################################################################### -# -# Generator functions for all layouts -# -################################################################################################### - -# -def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128): - - for tile in tile_descriptions: - for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: - - if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): - - # - output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ - if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ - else [tile.math_instruction.element_accumulator,] - - for output_type in output_types: - A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) - B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) - C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type]))) - - manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) - -class EmitConv3dIncludes: - '''Emit includes that are specific to the operation.''' - - def __init__(self): - self.includes = ['conv3d_operation.h'] - self.emitter_3x = EmitConv3xIncludes() - - def operation_is_3x(self, operation) -> bool: - """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" - return hasattr(operation, 'is_3x') and operation.is_3x - - def emit(self, operation) -> str: - if self.operation_is_3x(operation): - return self.emitter_3x.emit(operation) - - return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ - "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitConv3dConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) - - self.instance_emitter = EmitConv3dInstance() - self.includes_emitter = EmitConv3dIncludes() - - self.header_template = """ -/* - Generated by conv3d_operation.py - Do not edit. -*/ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "library_internal.h" -""" - - self.instance_template = """ -${stub_begin} -${operation_instance} -// Derived class -struct ${operation_name} : - public ${operation_name}_base { }; -${stub_end} -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.configuration_header = """ - -namespace cutlass { -namespace library { - -// Initialize all instances -void initialize_${configuration_name}(Manifest &manifest) { -""" - - self.configuration_instance = """${stub_begin} - using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< - ${operation_name}>; - - manifest.append(new cutlass::library::${operation_wrapper}< - Operation_${operation_name} - >( - "${operation_name}" - )); -${stub_end} -""" - - self.configuration_epilogue = "}\n" - - self.epilogue_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def operation_is_3x(self, operation): - """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" - return hasattr(operation, 'is_3x') and operation.is_3x - - def __enter__(self): - """ - Open the configuration_file, and write the "header" C++ code to it. - - The "header" consists of a comment (that this is generated code, - so it should not be edited), and includes that are common - to both the CUTLASS 2 and the CUTLASS 3 cases. - """ - _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__') - _LOGGER.debug('*** configuration_path (file to write): ' + - str(self.configuration_path)) - _LOGGER.debug('*** configuration_name: ' + self.configuration_name) - self.configuration_file = open(self.configuration_path, "w") - - self.configuration_file.write(SubstituteTemplate(self.header_template, { - 'configuration_name': self.configuration_name - })) - self.operations = [] - return self - - def emit(self, operation): - """ - Write three pieces of C++ code to the configuration_file - (that was opened by the __enter__ method above): - - 1. the header includes that are specific to the operation - (CUTLASS 2 vs. CUTLASS 3); - - 2. the "operation instance" (a "using" declaration ending in "_base"); and - - 3. the "operation name" (declaration and definition of a derived class - of the above operation instance). - - The "using" declaration turns a C++ class name, possibly namespace-qualified, - possibly also with angle brackets, into a C-style, easily demangled identifier. - """ - _LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit') - _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) - self.operations.append(operation) - - self.configuration_file.write(self.includes_emitter.emit(operation)) - - stub_begin = '' - stub_end = '' - # It can be useful to stub (comment) out instantiations for testing. - # In this case, one need only set is_stub to True. - is_stub = False - if is_stub: - stub_begin = "// STUB for now\n#if 0" - stub_end = '#endif // 0' - - self.configuration_file.write(Template(self.instance_template).substitute({ - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'operation_instance': self.instance_emitter.emit(operation), - 'stub_begin': stub_begin, - 'stub_end': stub_end - })) - - def __exit__(self, exception_type, exception_value, traceback): - """ - Write the rest of the C++ code to the configuration_file, and close the file. - - The "rest of the C++ code" has the following components. - - 1. Configuration header: Open the namespace(s), and open the definition - of the "initialize_${configuration_name}" registration function - that registers the operation with the Manifest. - ("Registration" helps turn C++ compile-time polymorphism - (via template parameters) into a run-time choice of parameters.) - - 2. Configuration instance: In the body of the registration function, - make a "using" declaration Operation_${operation_name} for the - operation type (which uses operation_name as its template argument). - Then, tell the manifest about the operation via a "manifest.append" call. - The argument of the call is a new instance of - "SomethingOperation" - (replace Something with a specific name). - - 3. Configuration epilogue: Close the definition of the registration function. - - 4. Epilogue template: Close the namespace(s). - """ - - _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__') - _LOGGER.debug('*** configuration_path (file to write): ' + - str(self.configuration_path)) - _LOGGER.debug('*** configuration_name: ' + self.configuration_name) - - self.configuration_file.write(SubstituteTemplate(self.configuration_header, { - 'configuration_name': self.configuration_name - })) - - for operation in self.operations: - stub_begin = '' - stub_end = '' - # It can be useful to stub (comment) out instantiations for testing. - # In this case, one need only set is_stub to True. - is_stub = False - if is_stub: - stub_begin = "// STUB for now\n#if 0" - stub_end = "#endif // 0" - - kernel_name = 'ImplicitGemmConvolution' - operation_wrapper = 'Conv3dOperation' - if self.operation_is_3x(operation): - kernel_name = 'ConvUniversalAdapter' - operation_wrapper = 'ConvOperation3x' - - self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'kernel_name': kernel_name, - 'operation_wrapper': operation_wrapper, - 'stub_begin': stub_begin, - 'stub_end': stub_end - })) - - self.configuration_file.write(self.configuration_epilogue) - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - - -################################################################################################### -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py deleted file mode 100644 index 33d6da1a4675c0bbd07315717a7f5ba0ba0dc10c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py +++ /dev/null @@ -1,250 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting CUTLASS >= 3 convolution kernels -""" - -import enum -import os.path -import shutil -import logging -from string import Template - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - -_LOGGER = logging.getLogger(__name__) - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -class EmitConv3xInstance: - def __init__(self): - _LOGGER.debug("*** EmitConv3xInstance::__init__") - - # Define epilogue type first, so that the mainloop type - # can use it with StageCountAutoCarveout. - self.template = """ - -// CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}" -using ${operation_name}_epilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ${arch}, - ${opcode_class_epi}, - ${mma_tile_shape}, // mma tile shape - ${cluster_shape}, // cluster shape - ${epi_tile_mn}, - ${element_accumulator}, - ${element_compute}, - ${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>, - ${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>, - ${epilogue_schedule} - // , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination - >::CollectiveOp; - -using ${operation_name}_mainloop = - typename cutlass::conv::collective::CollectiveBuilder< - ${arch}, - ${opcode_class_main}, - ${conv_kind}, // kFprop, kDgrad, or kWgrad - ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>, - ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>, - ${element_accumulator}, - ${mma_tile_shape}, // mma tile shape - ${cluster_shape}, // cluster shape - ${stages}, - ${kernel_schedule} - >::CollectiveOp; - -using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>; - -// Unit tests call this "ConvKernel". -// Conv operator ${operation_name} -using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal< - ${operation_name}_problem_shape, - ${operation_name}_mainloop, - ${operation_name}_epilogue, - ${tile_scheduler} - >; -""" - - def arch_number_to_type(self, arch: int) -> str: - return f"cutlass::arch::Sm{arch}" - - def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: - mma_m = cta_m - mma_n = cta_n - mma_k = cta_k - - if operation.arch >= 100: - # MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where - # mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version. - # If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated, - # otherwise 1sm kernel is allocated. - cta_m_per_mma_instruction = 1 - if "2sm" in operation.procedural_name() : - cta_m_per_mma_instruction = 2 - elif "1sm" in operation.procedural_name() : - cta_m_per_mma_instruction = 1 - elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 : - cta_m_per_mma_instruction = 2 - mma_m = cta_m * cta_m_per_mma_instruction - - # For all three kinds of convolutions, the tile shape's K mode - # differs from GEMM in that needs to be wrapped in a Shape. - # For Wgrad convolutions specifically, - # the N tile shape also needs to be wrapped in a Shape. - m_template = 'cute::_${mma_m}' - if operation.conv_kind == ConvKind.Wgrad: - n_template = 'cute::Shape' - else: - n_template = 'cute::_${mma_n}' - k_template = 'cute::Shape' - - mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' - values = { - 'mma_m': mma_m, - 'mma_n': mma_n, - 'mma_k': mma_k - } - return Template(mma_tile_shape_template).substitute(values) - - def cluster_shape(self, operation) -> str: - m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)' - n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)' - k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)' - cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' - values = { - 'cluster_shape_m': operation.tile_description.cluster_shape[0], - 'cluster_shape_n': operation.tile_description.cluster_shape[1], - 'cluster_shape_k': operation.tile_description.cluster_shape[2], - } - return Template(cluster_shape_template).substitute(values) - - def stage_count(self, operation) -> str: - # stages == 0 tells builder to pick the number of stages automatically - namespace_prefix = 'cutlass::conv::collective::' - if operation.tile_description.stages > 0: - return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>" - else: - return f"{namespace_prefix}StageCountAutoCarveout" - - def emit(self, operation) -> str: - _LOGGER.debug("*** EmitConv3xInstance::emit") - _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) - - # Identify the operation as CUTLASS 3 by its is_3x field - if (not hasattr(operation, 'is_3x')) or (not operation.is_3x): - raise RuntimeError("operation must be a CUTLASS 3 operation") - - epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" - opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] - opcode_class_epi = opcode_class_main - - tile_shape = operation.tile_description.tile_shape - cluster_m = operation.tile_description.cluster_shape[0] - cluster_n = operation.tile_description.cluster_shape[1] - - cta_m, cta_n, cta_k = tile_shape - # account for static/dynamic cluster shapes - if operation.arch >= 100: - cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m - cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n - - warp_count = operation.tile_description.warp_count - epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule] - - # KernelScheduleTag and TileSchedulerTag both hard-code the - # namespace qualification of KernelScheduleAuto as - # "cutlass::gemm::collective::" (unless the tag is 'void'). - # - # For TileSchedulerTag, this namespace is fine, since CUTLASS 3 - # convolutions use the same tile schedulers (from the same - # cutlass::gemm::collective namespace) as GEMMs. - kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::') - tile_scheduler = TileSchedulerTag[operation.tile_scheduler] - opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] - - values = { - 'operation_name': operation.procedural_name(), - 'conv_kind': ConvKindTag[operation.conv_kind], - 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'align_a': int(operation.A.alignment), - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'align_b': int(operation.B.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'align_c': int(operation.C.alignment), - 'element_d': DataTypeTag[operation.D.element], - 'layout_d': LayoutTag[operation.D.layout], - 'align_d': int(operation.D.alignment), - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': opcode_class, - 'arch': self.arch_number_to_type(operation.arch), - 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k), - 'cluster_shape': self.cluster_shape(operation), - 'opcode_class_epi': opcode_class_epi, - 'opcode_class_main': opcode_class_main, - 'epi_tile_mn': epi_tile_mn, - 'stages': self.stage_count(operation), - 'kernel_schedule': kernel_schedule, - 'epilogue_schedule': epilogue_schedule, - 'tile_scheduler': tile_scheduler, - 'element_compute': DataTypeTag[operation.element_compute] - } - return Template(self.template).substitute(values) - -class EmitConv3xIncludes: - def __init__(self): - _LOGGER.debug("*** EmitConv3xIncludes::__init__") - self.includes = ['conv_operation_3x.hpp', - 'cutlass/conv/device/conv_universal_adapter.hpp', - 'cutlass/conv/kernel/conv_universal.hpp', - 'cutlass/conv/collective/collective_builder.hpp', - 'cutlass/epilogue/collective/collective_builder.hpp'] - - def emit(self, operation) -> str: - _LOGGER.debug("*** EmitConv3xIncludes::emit") - return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ - "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py deleted file mode 100644 index fbe52eb587ab1b5e4595739be5790151b00e0a70..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py +++ /dev/null @@ -1,868 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -# -# -# \brief Generates the CUTLASS kernel listing with kernel filtering -# - -# - -############################################################################### -# Example usage: -# generator.py --operations all --generator-target kernel_listing \ -# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports -############################################################################### - -import collections -import csv -import json -import math -import os - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - -audit_csv_fields = [ - "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD", - "Layout_A", "Layout_B", "Layout_C", "Layout_D", - "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D", - "1SM/2SM", - "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types", - "Test Counts" -] - -audit_csv_runtime_fields = [ - "KerneIndex", "KernelName", - "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K", - "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K", - "M", "N", "K", "L", "Alpha_val", "Beta_val", - "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled" -] - -def hash_cutlass_string(input_string): - mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') - - # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') - output = re.sub(mma_cluster_shape_pattern, "", input_string) - - return output - -def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b): - # Define a dictionary mapping the detected types to runtime values - datatype_map = { - 'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b, - 'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b, - 'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b, - 'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b, - 'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b, - 'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b, - 'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b, - 'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b, - 'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b, - 'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b, - 'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - 'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, - } - - # Regular expression to detect all the keys in datatype_map - pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')') - - # Replace detected patterns using the dictionary - updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name) - - return updated_kernel_name - -# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k. -def get_kernel_features(operation, kernel_name, - dynamic_datatype, runtime_input_datatype): - numcta_inst = "2sm" if "2sm" in kernel_name else "1sm" - math_inst = operation.tile_description.math_instruction - - if dynamic_datatype: - dtype_name_A = runtime_input_datatype[0] - dtype_name_B = runtime_input_datatype[1] - else: - dtype_name_A = DataTypeNames[operation.A.element] - dtype_name_B = DataTypeNames[operation.B.element] - - layout_name_A = ShortLayoutTypeNames[operation.A.layout] - layout_name_B = ShortLayoutTypeNames[operation.B.layout] - layout_name_C = ShortLayoutTypeNames[operation.C.layout] - layout_name_D = ShortLayoutTypeNames[operation.D.layout] - - scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void - scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void) - audit_vals = [ - "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM", - kernel_name, - dtype_name_A, - dtype_name_B, - DataTypeNames[operation.C.element], - DataTypeNames[operation.tile_description.math_instruction.element_accumulator], - DataTypeNames[operation.element_epilogue], - DataTypeNames[operation.D.element], - DataTypeNames[scale_factor_D_type], - DataTypeNames[scale_factor_A_type], - layout_name_A, - layout_name_B, - layout_name_C, - layout_name_D, - str(operation.A.alignment), - str(operation.B.alignment), - str(operation.C.alignment), - str(operation.D.alignment), - numcta_inst, - "Y" if 'stream_k' in kernel_name else "N", - ] - return audit_vals - -# This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta. -def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster): - math_inst = operation.tile_description.math_instruction - audit_vals = [ - str(math_inst.instruction_shape[0]), - str(math_inst.instruction_shape[1]), - str(math_inst.instruction_shape[2]), - str(operation.tile_description.threadblock_shape[0]), - str(operation.tile_description.threadblock_shape[1]), - str(operation.tile_description.threadblock_shape[2]), - str(operation.tile_description.cluster_shape[0]), - str(operation.tile_description.cluster_shape[1]), - str(operation.tile_description.cluster_shape[2]), - str(cluster_shape[0]), - str(cluster_shape[1]), - str(cluster_shape[2]), - str(fallback_cluster_shape[0]), - str(fallback_cluster_shape[1]), - str(fallback_cluster_shape[2]), - str(problem_shape[0]), - str(problem_shape[1]), - str(problem_shape[2]), - str(problem_shape[3]), - str(alpha), - str(beta), - "Y" if dynamic_datatype else "N", - "Y" if dynamic_cluster else "N", - ] - return audit_vals - - -def _getSubOperationType(kernel): - - if kernel.operation_kind == OperationKind.Gemm: - return GemmKindNames[kernel.gemm_kind] - elif kernel.operation_kind == OperationKind.Conv2d: - return "conv_" + ConvKindNames[kernel.conv_kind] - elif kernel.operation_kind == OperationKind.Syrk: - return "syrk_" + SyrkKindNames[kernel.syrk_kind] - elif kernel.operation_kind == OperationKind.Trmm: - return "trmm_" + TrmmKindNames[kernel.trmm_kind] - elif kernel.operation_kind == OperationKind.Symm: - return "symm_" + SymmKindNames[kernel.symm_kind] - else: - raise Exception("Unsupported kernel type") - -def _get_inst_shape(math_instruction): - return "".join(str(x) for x in math_instruction.instruction_shape) - -def _is_simt_inst(math_instruction): - return _get_inst_shape(math_instruction) in ["111","114"] - -def _getInstType(input_precision, accumulate_precision, math_instruction): - - # inst_shape - inst_shape = _get_inst_shape(math_instruction) - - # input precision - if input_precision == "fp32" and inst_shape != "111": - inp = "tf32" - else: - inp = input_precision - - # Handle SIMT op types first - if _is_simt_inst(math_instruction): - - simt_input_precision_to_inst = { - "fp32": "FFMA", - "fp64": "DFMA", - "fp16": "HFMA", - "int8": "IDP4A", - } - inst = simt_input_precision_to_inst[input_precision] - - else: # Tensor op instructions - - if accumulate_precision == "cf64": - fp64_acc_map = { - MathOperation.multiply_add_complex_gaussian : "gz", - MathOperation.multiply_add_complex : "z", - } - acc = fp64_acc_map[math_instruction.math_operation] - else: - tensor_op_acc_map = { - "fp32" : "s", - "cf32" : "s", - "fp16" : "h", - "int32": "i", - "fp64" : "d", - } - acc = tensor_op_acc_map[accumulate_precision] - - inst = "{}{}{}".format(acc, inst_shape, inp) - - return inst -# TODO: Computes FLOps/Bytes for GEMM - revisit for conv -def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1): - assert not (batch_count > 1 and num_groups > 1) - - # TODO: adjust for sparsity - gmem_bytes = ( - (DataTypeSize[operation.A.element] * m // 8) * k + - (DataTypeSize[operation.B.element] * n // 8) * k + - (DataTypeSize[operation.C.element] * m // 8) * n - ) - - # TODO: complex-valued support - flops = 2 * (m * n * k) - - if bool(beta): - gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n - flops += 2 * m * n - - multiplier = max(batch_count, num_groups) - gmem_bytes *= multiplier - flops *= multiplier - - return flops / gmem_bytes - -def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode - ): - # For functional testing, we prefer to run reference computing on device if any - reference_device_archs = ["100a", "103a"] - run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False - profiler_flags_for_verification = "device" if run_reference_on_device else "host" - - # beta values for L0 and L1 - # TODO: randomize beta values for wider coverage - beta_values = [0.5] - - is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"]) - - is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch - - if (mode == "functional_L0") and is_supported_arch: - problem_waves = [0.5, 1.25, 2.5] - - # - # Dense Gemm - # - - sm100_mma_data_type_general = [ - 'gemm_f16_f16_f16_f16_f16', - 'gemm_f16_f16_f16_void_f16', - #'gemm_f16_f16_f32_f16_f16', - 'tf32gemm_f32_f32_f32_f32_f32', - 'bf16gemm_f32_f32_f32_f32_f32', - ] - - exclude_archs = arch not in ("103a") - if exclude_archs: - sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8') - - sm100_mma_data_type_runtime_dtype = [ - 'gemm.*f4_f4_f32_f32_f32', - 'gemm.*f6_f6_f32_f32_f32', - 'gemm.*f8_f8_f32_f32_f32', - ] - - sm100_mma_cluster_size = [ - '8x1x1', - '4x4x1', '2x1x1', - '0x0x1' # dynamic cluster - ] - - # Restrict to two layouts to reduce L0 build and test time. - sm100_mma_layouts = [ - 'tnt', - 'ntn' - ] - - # regex list must be in kernel procedural name order - sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" - sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" - - sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" - sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" - - # - # Block Scale Gemm - # - - block_scaled_data_type = [ - # runtime datatypes - 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', - 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2', - 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', - #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', - 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', - ] - - block_scaled_tile_k = ['x128_', 'x256_'] - - sm103_block_scaled_data_type = [ - 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', - 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', - ] - - sm103_block_scaled_tile_k = ['x768_'] - - block_scaled_cluster_size = [ - '4x4x1', '2x1x1', - '0x0x1' # dynamic cluster - ] - - block_scaled_layouts = ['tnt'] - # regex list must be in kernel procedural name order - block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - - sm103_block_scaled_prefetch_policy = ['tmapf'] - sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" - sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" - - if arch in ["100a", "100f"]: - kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ - f"({sm100_mma_filter_regex_2sm})|" \ - f"({sm100_mma_filter_regex_1sm_runtime})|" \ - f"({sm100_mma_filter_regex_2sm_runtime})|" \ - f"({block_scaled_filter_regex_1sm})|" \ - f"({block_scaled_filter_regex_2sm})" - elif arch in ["101a", "101f", "110a", "110f"]: - kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ - f"({sm100_mma_filter_regex_2sm})|" \ - f"({sm100_mma_filter_regex_1sm_runtime})|" \ - f"({sm100_mma_filter_regex_2sm_runtime})|" \ - f"({block_scaled_filter_regex_1sm})|" \ - f"({block_scaled_filter_regex_2sm})" - elif arch in ["103a"]: - kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ - f"({sm100_mma_filter_regex_2sm})|" \ - f"({sm100_mma_filter_regex_1sm_runtime})|" \ - f"({sm100_mma_filter_regex_2sm_runtime})|" \ - f"({block_scaled_filter_regex_1sm})|" \ - f"({block_scaled_filter_regex_2sm})|" \ - f"({sm103_block_scaled_filter_regex_1sm})|" \ - f"({sm103_block_scaled_filter_regex_2sm})" - elif arch in ["120a", "120f", "121a", "121f"]: - - # blockscaled sm120_mma kernels - blockscaled_sm120_mma_kernel_cta_tiles = [ - [ '128x128' ] - ] - - # Restrict to two layouts to reduce L0 build and test time. - blockscaled_sm120_mma_layouts = [ 'tn' ] - filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*" - - problem_waves = [0.5, 1.25, 2.5] - - kernel_filter = f"({filter_regex_blockscaled_sm120_mma})" - else: - error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f" - raise Exception(error_message) - - elif mode == "functional_L1": - sm100_mma_cluster_size = [ - '0x0x1' # dynamic cluster - ] - # Restrict to two layouts to reduce L1 build and test time. - sm100_mma_layouts = ['tnt', 'ntn'] - sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" - sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" - block_scaled_data_type = [ - 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', - 'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', - 'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2', - 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', - 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', - ] - - sm103_block_scaled_data_type = [ - 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', - 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', - ] - - block_scaled_cluster_size = ['0x0x1'] - block_scaled_layouts = ['tnt'] - - # regex list must be in kernel procedural name order - block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - - sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - - filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ - f"({sm100_mma_filter_regex_2sm})|" \ - f"({block_scaled_filter_regex_1sm})|" \ - f"({block_scaled_filter_regex_2sm})" \ - f"({sm103_block_scaled_filter_regex_1sm})|" \ - f"({sm103_block_scaled_filter_regex_2sm})" - # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times - sm120_mma_kernel_cta_tiles = [ - # h1688, s1688, i16832, i8816 - [ '256x128' ], - # d884, c1688, - [ '128x128' ], - # c1688, z884 - [ '128x64' ], - # gz884 - [ '64x64' ] - ] - - # sm120 MMA instruction shapes, planar complex type excluded as they are not required - sm120_mma_instruction_shapes = [ - [ 'h1688gemm_(?!planar_complex)', - 's1688gemm_f16', - 's1688gemm_bf16', - 's1688gemm_tf32', - 'i16832gemm', - 'i8816gemm' ], - [ 'd884gemm', 'c1688tf32gemm' ] , - [ 'c1688gemm', - 'z884gemm' ], - [ 'gz884gemm'] - ] - - # It's not pretty, but not sure why different instructions support different tile sizes. - filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*" - filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*" - filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*" - filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*" - - filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})" - - problem_waves = [0.5, 1.25, 2.5] - - if arch in ["120a", "120f", "121a", "121f"]: - kernel_filter = f"({filter_regex_sm120_mma})" - else: - kernel_filter = f"({filter_regex_sm100_mma})" - else: - raise ValueError() - - outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") - - audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv") - - audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv") - - kernel_filter_re = re.compile(kernel_filter) - testcase_counter = 0 - kernels_emitted = 0 - kernels_total = 0 - - perf_json_list = [] - kernel_name_set = set() - - testlist_csv_fields = ["testcase", "metadata"] - testlist_csv_rows = [] - auditlist_csv_map = {} - auditlist_csv_params_map = {} - - kernel_features = {} - - for cc in manifest.operations[OperationKind.Gemm].keys(): - for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items(): - assert(len(operation_l) == 1) - kernels_total += 1 - if len(kernel_filter_re.findall(kernel_name)) == 0: - continue - # Only test f16 I/O void C kernels in void C kernel set - # Exception: Use void C kernels for more accurate perf testing - if '_void_' in kernel_name and 'perf_' not in mode: - if 'f16_f16_f16_void_f16' not in kernel_name : - continue - - kernels_emitted += 1 - kernel_name_set.add(kernel_name) - hashed_kernel_name = hash_cutlass_string(kernel_name) - operation = operation_l[0] - - dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0 - or operation.tile_description.cluster_shape[1] == 0) - - dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name - - runtime_input_datatypes = [None] - - if dynamic_datatype: - if "f4_f4" in kernel_name: - runtime_input_datatypes = [['e2m1','e2m1']] - elif "f4_f6" in kernel_name: - runtime_input_datatypes = [['e2m1','e3m2']] - elif "f4_f8" in kernel_name: - runtime_input_datatypes = [['e2m1','e4m3']] - - elif "f6_f4" in kernel_name: - runtime_input_datatypes = [['e3m2','e2m1']] - elif "f6_f6" in kernel_name: - runtime_input_datatypes = [['e3m2','e3m2']] - elif "f6_f8" in kernel_name: - runtime_input_datatypes = [['e3m2','e4m3']] - - elif "f8_f4" in kernel_name: - runtime_input_datatypes = [['e4m3','e2m1']] - elif "f8_f6" in kernel_name: - runtime_input_datatypes = [['e4m3','e3m2']] - elif "f8_f8" in kernel_name: - runtime_input_datatypes = [ - # mask out those not covered in statically encoded test cases - # ['e5m2','e4m3'], - # ['e4m3','e5m2'], - ['e4m3','e4m3'] - ] - - # block scaled kernels - elif "ue8m0xf4_ue8m0xf4" in kernel_name: - runtime_input_datatypes = [['e2m1','e2m1']] - elif "ue4m3xf4_ue4m3xf4" in kernel_name: - runtime_input_datatypes = [['e2m1','e2m1']] - elif "ue8m0xf4_ue8m0xf6" in kernel_name: - runtime_input_datatypes = [['e2m1','e2m3']] - elif "ue8m0xf4_ue8m0xf8" in kernel_name: - runtime_input_datatypes = [['e2m1','e4m3']] - - elif "ue8m0xf6_ue8m0xf4" in kernel_name: - runtime_input_datatypes = [['e2m3','e2m1']] - elif "ue8m0xf6_ue8m0xf6" in kernel_name: - runtime_input_datatypes = [['e2m3','e2m3']] - elif "ue8m0xf8_ue8m0xf4" in kernel_name: - runtime_input_datatypes = [['e4m3','e2m1']] - - elif "ue8m0xf8_ue8m0xf4" in kernel_name: - runtime_input_datatypes = [['e4m3','e2m1']] - elif "ue8m0xf8_ue8m0xf6" in kernel_name: - runtime_input_datatypes = [['e4m3','e2m3']] - elif "ue8m0xf8_ue8m0xf8" in kernel_name: - runtime_input_datatypes = [['e4m3','e4m3']] - - if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): - profiler_flags_for_verification = "host" - - # reduce L1 test runtime if reference kernel is not running on device. - if mode == "functional_L1" and profiler_flags_for_verification == "host" : - problem_waves = [0.5, 2.5] - - - if dynamic_cluster: - if mode == "functional_L0": - runtime_cluster_shapes = [[1,1,1], [2,2,1]] - else: - runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] - # reduce L1 test runtime if reference kernel is not running on device. - if profiler_flags_for_verification == "host": - runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]] - cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape - else: - runtime_cluster_shapes = [operation.tile_description.cluster_shape] - cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0]) - cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1]) - cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2]) - - alignment_a = operation.A.alignment - alignment_b = operation.B.alignment - alignment_c = operation.C.alignment - alignment_ab_max = max(alignment_a, alignment_b) - - layout3x = operation.layout_name_3x() - data_types = operation.datatype_name_3x() - - ctas_per_mma_instruction = 1 - if '_2sm' in kernel_name: - ctas_per_mma_instruction = 2 - valid_cluster_shapes = [] - - # Remove any cluster shapes that have cluster_m that is not divisible by 2 - for cs in runtime_cluster_shapes: - if cs[0] % 2 == 0: - valid_cluster_shapes.append(cs) - runtime_cluster_shapes = valid_cluster_shapes - - kernel_problem_waves = problem_waves - if mode == "functional_L0" or mode == "functional_L1": - # for functional testing, we want to perturb just a little from even shapes - # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not - # -16 ensures that we are TMA aligned even for FP8/Int8 - min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max - max_k = (cta_tile_shape_k*8) - alignment_ab_max - problem_shapes_k = [min_k, max_k] - sm_count = 16 - swizzle_sizes = [0] - # Larger k and less than half wave trigger streamk +separate reduction case to be generated - if 'stream_k' in kernel_name: - problem_shapes_k = [max_k, cta_tile_shape_k*32] - kernel_problem_waves = [0.125, 1.25, 2.5] - else: - raise ValueError - - if "void" in kernel_name: - beta_values = [0] - - alignment_shift_m = max(alignment_c, alignment_a) - alignment_shift_n = max(alignment_c, alignment_b) - - is_first_line = True - for index_waves, waves in enumerate(kernel_problem_waves): - for index_k, k in enumerate(problem_shapes_k): - for beta in beta_values: - for cluster_shape in runtime_cluster_shapes: - for runtime_input_datatype in runtime_input_datatypes: - for swizzle_size in swizzle_sizes: - grid_size = waves * sm_count - cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape) - if cluster_shape_m >= cluster_shape_n: - grid_m = cluster_shape_m - grid_n = grid_size / grid_m - grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1) - else: - grid_n = cluster_shape_n - grid_m = grid_size / grid_n - grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1) - - verification_required = False - if mode == "functional_L0" or mode == "functional_L1": - if '_void_' not in kernel_name: - verification_required = True - - m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max) - n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max) - k = int(k) - - # For functional testing, we want to perturb just a little from even shapes. - # Only do this if the perturbation does not cause one of the dimensions of the - # problem size to go to zero. This can occur for blockscaling kernels for which - # the alignment requirements for A and B can be quite large (e.g., 256). - if m > alignment_shift_m: - m -= alignment_shift_m - if n > alignment_shift_n: - n -= alignment_shift_n - - if '_n32t32_' in kernel_name: - continue - batch_count = 1 - if mode == "functional_L0" or mode == "functional_L1" : - if index_waves == 0 and index_k == 0 : - batch_count = 3 if mode == "functional_L0" else 5 - gemm_op = "gemm" - - grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind) - num_groups = 1 - if grouped: - gemm_op = "grouped_gemm" - num_groups = 3 # small to limit test time in host block-scaled reference kernels - batch_count = 1 - elif "bstensorop" in kernel_name: - gemm_op = "block_scaled_gemm" - elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): - gemm_op = "blockwise_gemm" - - problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)] - - assert m > 0 and n > 0 and k > 0 - - # Emit per-testcase metadata for perf testing usage, eventually in perf database - metadata_dict = { - "input_params": { - 'problem_size_category' : problem_size_category, - 'operation' : _getSubOperationType(operation), - 'datatype' : data_types, - 'layout' : layout3x, - 'm' : m, - 'n' : n, - 'k' : k, - 'beta' : beta, - 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups) - }, - "runtime_params": { - 'ctas_per_mma_instruction' : ctas_per_mma_instruction, - 'tilesize_m' : cta_tile_shape_m, - 'tilesize_n' : cta_tile_shape_n, - 'tilesize_k' : cta_tile_shape_k, - 'cluster_shape_m' : cluster_shape_m, - 'cluster_shape_n' : cluster_shape_n, - } - } - - cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m - cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n - cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k - - - if dynamic_datatype: - runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype) - metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a - metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b - - testcase_metadata = [ - f"cutlass_profiler --operation={gemm_op}" + - (f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") + - f" --error-on-no-match --error-if-nothing-is-profiled" + - f" --kernels={kernel_name}" + - f" --m={str(m)}" + - f" --n={str(n)}" + - f" --k={str(k)}" + - (f" --num_groups={str(num_groups)}" if grouped else "") + - f" --cluster_m={str(cluster_shape_m)}" + - f" --cluster_n={str(cluster_shape_n)}" + - f" --cluster_k={str(cluster_shape_k)}" + - f" --cluster_m_fallback={str(cluster_m_fallback)}" + - f" --cluster_n_fallback={str(cluster_n_fallback)}" + - f" --cluster_k_fallback={str(cluster_k_fallback)}" + - f" --beta={str(beta)}" + - ("" if grouped else f" --batch_count={str(batch_count)}") + - f" --swizzle_size={str(swizzle_size)}" + - f" --verification-required={str(verification_required).lower()}" - ] \ - - output_dynamic_datatype = dynamic_datatype - if output_dynamic_datatype: - testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" + - f" --runtime_input_datatype_b={runtime_datatype_b}") - - testcase_metadata.append(json.dumps(metadata_dict)) - testlist_csv_rows.append(testcase_metadata) - testcase_counter += 1 - - alpha = 1.0 - - if dynamic_datatype: - hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b) - - # If kernel_name is new, initialize its feature set with defaults - if hashed_kernel_name not in kernel_features: - kernel_features[hashed_kernel_name] = { - "is_support_dynamic_cluster": False, - "is_support_dynamic_datatype": False, - } - - # Update features for the hashed kernel name - kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster - kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype - - if hashed_kernel_name not in auditlist_csv_params_map: - auditlist_csv_params_map[hashed_kernel_name] = [] - - audit_row_params = get_kernel_params( - operation, - hashed_kernel_name, - (cluster_shape_m, cluster_shape_n, cluster_shape_k), - (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback), - (m, n, k, batch_count), - alpha, beta, - dynamic_datatype, dynamic_cluster - ) - - auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params) - - if hashed_kernel_name not in auditlist_csv_map: - audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype) - auditlist_csv_map[hashed_kernel_name] = audit_row - - with open(outfile_name, 'w') as testlist_csv: - csv_writer = csv.writer(testlist_csv, delimiter=',') - csv_writer.writerow(testlist_csv_fields) - csv_writer.writerows(testlist_csv_rows) - - with open(audit_file_name, 'w') as auditlist_csv: - csv_writer = csv.writer(auditlist_csv, delimiter=',') - csv_writer.writerow(audit_csv_fields) - for hashed_kernel_name, row in auditlist_csv_map.items(): - # Append the dynamic features as "Y" or "N" - dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N" - dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N" - test_count = len(auditlist_csv_params_map[hashed_kernel_name]) - csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count]) - - with open(audit_file_params_name, 'w') as auditlist_csv: - csv_writer = csv.writer(auditlist_csv, delimiter=',') - csv_writer.writerow(audit_csv_runtime_fields) - for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1): - for i, row in enumerate(rows): - if i == 0: - csv_writer.writerow([kernel_index, hashed_kernel_name] + row) - else: - csv_writer.writerow(["", ""] + row) - - print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.") - - # Generate a newline separated list of kernel filters - assert(len(kernel_name_set) == kernels_emitted) - output_filter_enabled = True - if output_filter_enabled: - kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") - with open(kernel_filter_outfile_name, "w") as file: - kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set)) - for kernel_name in kernel_name_set: - file.write(kernel_name + "\n") - - # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together. - if mode == "functional_L0" or mode == "functional_L1": - # Sort the .csv file - outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") - with open(outfile_name) as file: - data = file.readlines() - data.sort() - with open(outfile_name, 'w') as file: - for i in range(len(data)): - file.write(data[i]) - # Sort the kernel list - kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") - with open(kernel_filter_outfile_name) as file: - data = file.readlines() - data.sort() - with open(kernel_filter_outfile_name, 'w') as file: - for i in range(len(data)): - file.write(data[i]) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py deleted file mode 100644 index 0d2449e769303b738212cdcd896c9f2793ca2632..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py +++ /dev/null @@ -1,1613 +0,0 @@ - -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting GEMM kernels -""" - -import collections -import enum -import functools -import logging -import operator -import os.path -import shutil - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - -_LOGGER = logging.getLogger(__name__) - -################################################################################################### -# -# Data structure modeling a GEMM operation -# -################################################################################################### - -# -class GemmOperation: - # - def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, - kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, - tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False, - ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None, - ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None): - - kinds_3x = { - GemmKind.Universal3x, - GemmKind.SparseUniversal3x, - GemmKind.BlockScaledUniversal3x, - GemmKind.GroupedUniversal3x, - GemmKind.GroupedBlockScaledUniversal3x, - GemmKind.BlockwiseUniversal3x, - GemmKind.GroupedBlockwiseUniversal3x, - } - self.is_3x = gemm_kind in kinds_3x - self.prefix = "3x" if self.is_3x else "" - self.operation_kind = OperationKind.Gemm - self.arch = arch - self.tile_description = tile_description - self.gemm_kind = gemm_kind - self.A = A - self.B = B - self.C = C - self.D = D - - if is_block_scaled(gemm_kind): - self.ScaleFactorA = ScaleFactorA - self.ScaleFactorB = ScaleFactorB - self.ScaleFactorD = ScaleFactorD["tensor"] - self.ScaleFactorVectorSize = ScaleFactorD["vector_size"] - - if is_blockwise(gemm_kind): - self.ScaleFactorMVecSize = ScaleFactorMVecSize - self.ScaleFactorNVecSize = ScaleFactorNVecSize - self.ScaleFactorKVecSize = ScaleFactorKVecSize - - if self.D == None: - self.D = self.C - - if not self.is_3x: - assert(kernel_schedule == KernelScheduleType.ScheduleAuto) - assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto) - self.kernel_schedule = kernel_schedule - self.epilogue_schedule = epilogue_schedule - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - - if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination: - self.epilogue_functor = EpilogueFunctor3x.LinearCombination - - self.swizzling_functor = swizzling_functor - self.tile_scheduler = tile_scheduler - - # Only enable mixed input mode and mixed input shuffle for Hopper - self.mixed_input_mode = None - if self.is_mixed_input() and self.arch >= 90 and self.arch < 100: - self.mixed_input_mode = mixed_input_mode - self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle - - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 - ] - return self.tile_description.math_instruction.math_operation in complex_operators - - # - def is_mixed_input(self): - return self.A.element != self.B.element - - # - def is_planar_complex(self): - return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - MathOperation.and_popc: 'and', - MathOperation.multiply_add_fast_accum: 'fastaccum', - } - - tensor_ops = [ - OpcodeClass.TensorOp, - OpcodeClass.WmmaTensorOp, - OpcodeClass.SparseTensorOp, - OpcodeClass.BlockScaledTensorOp, - ] - - is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops - - if is_tensor_op: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else "" - - inst_shape += math_op_string - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - short_math_name = self.short_math_name() if not self.is_3x else "" - - return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) - - # Generates a string representing the MMA instruction. - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - element_sfa = "" - element_sfb = "" - if self.is_complex(): - extended_name = "${core_name}" - else: - if self.is_mixed_input(): - extended_name = "${core_name}_${element_a}_${element_b}" - if self.C.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_" + extended_name - elif is_blockwise(self.gemm_kind): - extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}" - element_sfa = DataTypeNames[self.accumulator_type()] - element_sfb = DataTypeNames[self.accumulator_type()] - else: - extended_name = "${core_name}" - if self.C.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_" + extended_name - if self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name += "_${element_a}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_sfa' : element_sfa, - 'element_b': DataTypeNames[self.B.element], - 'element_sfb' : element_sfb, - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def mixed_input_mode_name(self): - mode_name_mapping = { - MixedInputMode.ConvertOnly: "_cvt", - MixedInputMode.ScaleOnly: "_scl", - MixedInputMode.ScaleWithZeroPoint: "_sclzr" - } - mode_name = mode_name_mapping.get(self.mixed_input_mode, "") - if self.mixed_input_shuffle: - mode_name = mode_name + "_shfl" - return mode_name - - def extended_name_3x(self): - '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' - extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( - element_a = DataTypeNames[self.A.element], - element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.accumulator_type()], - element_c = DataTypeNames[self.C.element], - element_d = DataTypeNames[self.D.element], - core_name = self.core_name()) - - if is_block_scaled(self.gemm_kind): - d_type_names = DataTypeNames[self.D.element] - - if self.ScaleFactorD.element != DataType.void: - d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names - - extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( - element_sfa = DataTypeNames[self.ScaleFactorA], - element_a = DataTypeNames[self.A.element], - element_sfb = DataTypeNames[self.ScaleFactorB], - element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.accumulator_type()], - element_c = DataTypeNames[self.C.element], - element_d = d_type_names, - core_name = self.core_name()) - - if is_blockwise(self.gemm_kind): - d_type_names = DataTypeNames[self.D.element] - - extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( - element_sfa = DataTypeNames[self.accumulator_type()], - element_a = DataTypeNames[self.A.element], - element_sfb = DataTypeNames[self.accumulator_type()], - element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.accumulator_type()], - element_c = DataTypeNames[self.C.element], - element_d = d_type_names, - sfvec_m_size = self.ScaleFactorMVecSize, - sfvec_n_size = self.ScaleFactorNVecSize, - sfvec_k_size = self.ScaleFactorKVecSize, - core_name = self.core_name()) - - if self.mixed_input_mode != None: - extended_name = extended_name + self.mixed_input_mode_name() - return extended_name - - def datatype_name_3x(self): - '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' - datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( - element_a = DataTypeNames[self.A.element], - element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.accumulator_type()], - element_c = DataTypeNames[self.C.element], - element_d = DataTypeNames[self.D.element]) - return datatype_name - - # Generates a short string representing the AB layout tags (e.g. nt or tn) - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] - ) - return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) - - # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) - def layout_name_3x(self): - if self.is_complex() or self.is_planar_complex(): - return "{}{}{}".format( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], - ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) - else: - return "{}{}{}".format( - ShortLayoutTypeNames[self.A.layout], - ShortLayoutTypeNames[self.B.layout], - ShortLayoutTypeNames[self.C.layout]) - - # Generates a short string representing underlying kernel schedule type - def kernel_schedule_name_3x(self): - return KernelScheduleSuffixes[self.kernel_schedule] - - # Generates a short string representing underlying epilogue schedule type - def epilogue_schedule_name_3x(self): - - if is_block_scaled(self.gemm_kind): - if self.ScaleFactorD.element != DataType.void: - return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout] - - return EpilogueScheduleSuffixes[self.epilogue_schedule] - - # Generate a short string representing the operation class - def opcode_class_name(self): - return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - def get_collective_tile_shape(self): - """ - Get the tile shape passed to the collective builder. - On Blackwell, this is different than the operation.tile_description.tile_shape. - """ - is_sm100_kernel = (self.arch == 100 or self.arch == 103) - if not is_sm100_kernel: - return self.tile_description.tile_shape - - opcode_class_main = self.tile_description.math_instruction.opcode_class - instruction_shape = self.tile_description.math_instruction.instruction_shape - tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape - if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]: - tile_shape_m = instruction_shape[0] - tile_shape_n = instruction_shape[1] - return (tile_shape_m, tile_shape_n, tile_shape_k) - - # Generates the full kernel function name - def procedural_name(self): - return self._procedural_name - - @functools.cached_property - def _procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - if self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}" - tile_shape = self.get_collective_tile_shape() - return kernel_name_template.format( - p = self.prefix, - ar = self.arch, - op = opcode_class_name, - ex = self.extended_name_3x(), - ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "", - cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), - l = self.tile_description.stages, - s = self.layout_name_3x(), - al = str(max(self.A.alignment, self.B.alignment)), - t = TileSchedulerSuffixes[self.tile_scheduler], - k = self.kernel_schedule_name_3x(), - e = self.epilogue_schedule_name_3x()) - else: - threadblock = self.tile_description.procedural_name() - return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( - p = self.prefix, - op = opcode_class_name, - ex = self.extended_name(), - tb = threadblock, - l = self.layout_name(), - a = str(max(self.A.alignment, self.B.alignment))) - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() - - def __hash__(self): - return hash(self.configuration_name()) - - def __eq__(self, other): - return self.configuration_name() == other.configuration_name() - -################################################################################################### -# -# Data structure modeling a grouped GEMM operation -# -################################################################################################### - -# -class GroupedGemmOperation(GemmOperation): - # - def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ - scheduler_mode = GroupScheduleMode.Device): - super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor, swizzling_functor) - - self.scheduler_mode = scheduler_mode - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - base = super().procedural_name() - return SubstituteTemplate( - base + "_schedule${schedule}", - { - 'schedule': ShortGroupScheduleModeNames[self.scheduler_mode] - }) - - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -# -class EmitGemmInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [] - self.gemm_template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::Gemm< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - false, - ${math_operation} - ${residual} - >; -""" - self.gemm_complex_template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${transform_a}, - ${transform_b}, - ${math_operation} - ${residual} - >; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - residual = '' - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'residual': residual - } - - template = self.gemm_complex_template if operation.is_complex() else self.gemm_template - - return SubstituteTemplate(template, values) - -################################################################################################### - -class EmitSparseGemmInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [] - self.gemm_template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - false, - ${math_operation} - ${residual} - >; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - residual = '' - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'residual': residual - } - - template = self.gemm_template - - return SubstituteTemplate(template, values) - -################################################################################################### - - -# -class EmitGemmUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/numeric_types.h", - "cutlass/arch/arch.h", - "cutlass/arch/mma.h", - "cutlass/layout/matrix.h", - "cutlass/gemm/device/gemm.h", - "cutlass/gemm/device/gemm_universal_adapter.h", - "cutlass/gemm/kernel/default_gemm_universal.h", - ] - self.builtin_epilogue_functor_template = """ - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - > -""" - self.gemm_template = """ -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmUniversal< - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operation} ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - self.gemm_template_interleaved = """ -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operation} ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - warp_count = operation.tile_description.warp_count - - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - transpose_layouts = { - LayoutType.ColumnMajor: LayoutType.RowMajor, - LayoutType.RowMajor: LayoutType.ColumnMajor - } - - if operation.A.layout in transpose_layouts.keys() and \ - operation.B.layout in transpose_layouts.keys() and \ - operation.C.layout in transpose_layouts.keys(): - - instance_layout_A = transpose_layouts[operation.A.layout] - instance_layout_B = transpose_layouts[operation.B.layout] - instance_layout_C = transpose_layouts[operation.C.layout] - - gemm_template = self.gemm_template - else: - instance_layout_A, instance_layout_B, instance_layout_C = \ - (operation.A.layout, operation.B.layout, operation.C.layout) - - gemm_template = self.gemm_template_interleaved - # - - # Support built-in epilogue functors or user-defined functions - if isinstance(operation.epilogue_functor, enum.Enum): - - epilogue_vector_length = \ - min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] - - values = { - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - } - epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) - else: - epilogue_functor = self.epilogue_functor.emit_declaration() - # - - values = { - 'operation_name': operation.procedural_name(), - 'operation_suffix': self.operation_suffix, - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[instance_layout_B], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_functor': epilogue_functor, - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] - } - - return SubstituteTemplate(gemm_template, values) - - -################################################################################################### - -class EmitGemmUniversal3xInstance: - ''' Responsible for emitting a CUTLASS 3.x template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/gemm/gemm.h", - "cutlass/numeric_types.h", - "cutlass/gemm/kernel/gemm_universal.hpp", - "cutlass/gemm/collective/collective_builder.hpp", - "cutlass/epilogue/collective/collective_builder.hpp", - "cutlass/detail/blockwise_scale_layout.hpp", - ] - self.builtin_epilogue_functor_template = \ -"""${epilogue_functor}< - ${element_d}, - ${element_epilogue}, - ${element_c}, - ${element_epilogue} - >""" - - self.gemm_template = """ - -using ${operation_name}_epilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ${arch}, ${opcode_class_epi}, - cute::Shape, - cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, - ${epi_tile_mn}, - ${element_accumulator}, ${element_epilogue}, - ${element_c}, ${layout_c}, ${align_c}, - ${element_d}, ${layout_d}, ${align_d}, - ${epilogue_schedule}, - ${epilogue_functor} - >::CollectiveOp; - -${mixed_dtype_prepare_code} -${blockwise_prepare_code} - -using ${operation_name}_mainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ${arch}, ${opcode_class_main}, - ${element_a}, ${layout_a}, ${align_a}, - ${element_b}, ${layout_b}, ${align_b}, - ${element_accumulator}, - cute::Shape, - cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, - ${stages}, - ${kernel_schedule} - >::CollectiveOp; - -// Gemm operator ${operation_name} -using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< - ${problem_shape}, - ${operation_name}_mainloop, - ${operation_name}_epilogue, - ${tile_scheduler}>; - -// Define named type -struct ${operation_name} : - public ${operation_name}_base { }; - -""" - # - def instance_template(self): - return """ -${compile_guard_start} - { - using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; - manifest.append( - new ${gemm_kind}("${operation_name}")); - } -${compile_guard_end} -""" - - - def emit_block_scale_epilogue_functor(self, operation): - block_scaled_template = """ - ${epilogue_functor}< - ${epi_vs}, - ${element_d}, - ${element_accumulator}, - ${element_sfd}, - ${layout_sfd}, - ${element_c}, - ${element_scalar} - > - """ - block_scaled_values = { - 'epi_vs' : str(operation.ScaleFactorVectorSize), - 'element_d': str(DataTypeTag[operation.D.element]), - 'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]), - 'layout_sfd': LayoutTag[operation.ScaleFactorD.layout], - 'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor], - 'element_accumulator': str(DataTypeTag[operation.accumulator_type()]), - 'element_scalar': str(DataTypeTag[operation.accumulator_type()]), - 'element_c': str(DataTypeTag[operation.C.element]), - } - return SubstituteTemplate(block_scaled_template, block_scaled_values) - - - @staticmethod - def pointerize_if_grouped(operation, layout): - return layout if not is_grouped(operation.gemm_kind) else layout + "* " - - @staticmethod - def transform_layout_A_if_blockwise(operation, layout): - layout_sfa = f"{operation.procedural_name()}_LayoutSFA" - layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* " - return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfa}>" - - @staticmethod - def transform_layout_B_if_blockwise(operation, layout): - layout_sfb = f"{operation.procedural_name()}_LayoutSFB" - layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* " - return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>" - - @staticmethod - def problem_shape(operation): - gemm_shape_type = "cute::Shape" - grouped_gemm_shape_type = "cute::Shape" - grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" - - return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type - - def emit(self, operation): - _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") - _LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name()) - _LOGGER.debug("*** tile_shape: " + str(operation.tile_description.tile_shape)) - _LOGGER.debug("*** warp_count: " + str(operation.tile_description.warp_count)) - - opcode_class_main = operation.tile_description.math_instruction.opcode_class - opcode_class_epi = opcode_class_main - - tile_shape = operation.tile_description.tile_shape - instruction_shape = operation.tile_description.math_instruction.instruction_shape - cluster_m = operation.tile_description.cluster_shape[0] - cluster_n = operation.tile_description.cluster_shape[1] - cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] - tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape() - - # stage count set to zero indicates builder automatic stage selection - if operation.tile_description.stages > 0: - stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" - elif opcode_class_main == OpcodeClass.SparseTensorOp and operation.arch == 100: - stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveoutEpi<{str(operation.procedural_name())}_epilogue>" - else: - stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" - - epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" - - instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ - (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout) - - # 3.0 profiler integration only supports trivial epilogues for now - epilogue_vector_length = 1 - - # Support built-in epilogue functors or user-defined functions - if isinstance(operation.epilogue_functor, enum.Enum): - values = { - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor], - } - epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) - - if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: - epilogue_functor = self.emit_block_scale_epilogue_functor(operation) - - - else: - epilogue_functor = self.epilogue_functor.emit_declaration() - - if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: - epilogue_functor = self.emit_block_scale_epilogue_functor(operation) - - # - # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. - element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" - element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" - epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] - - if opcode_class_main == OpcodeClass.BlockScaledTensorOp: - grouped = is_grouped(operation.gemm_kind) - if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): - epi_tile_mn = "cute::Shape" - if is_tma_epilogue(operation.epilogue_schedule): - epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] - if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): - epi_tile_mn = "cute::Shape" - if is_tma_epilogue(operation.epilogue_schedule): - epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] - # SM103 FP4 Ultra - is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped) - ] - is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), - to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped) - ] - if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule: - epi_tile_mn = "cute::Shape" - if is_tma_epilogue(operation.epilogue_schedule): - epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] - if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule: - epi_tile_mn = "cute::Shape" - if is_tma_epilogue(operation.epilogue_schedule): - epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] - - element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' - element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' - - alignment_c = get_tma_alignment(operation.C.element) \ - if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ - else operation.C.alignment - alignment_d = get_tma_alignment(operation.D.element) \ - if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ - else operation.D.alignment - - operation_name_str = operation.procedural_name() - layout_a_str = LayoutTag[instance_layout_A] - layout_b_str = LayoutTag[instance_layout_B] - mixed_dtype_prepare_code = "" - if operation.mixed_input_mode != None: - A_dtype = operation.A.element - B_dtype = operation.B.element - A_dtype_bits = DataTypeSize[A_dtype] - B_dtype_bits = DataTypeSize[B_dtype] - is_A_dtype_narrow = A_dtype_bits < B_dtype_bits - if is_A_dtype_narrow: - narrow_dtype, wide_dtype = (A_dtype, B_dtype) - narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) - else: - narrow_dtype, wide_dtype = (B_dtype, A_dtype) - narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) - - narrow_tag = DataTypeTag[narrow_dtype] - wide_tag = DataTypeTag[wide_dtype] - scale_tag = DataTypeTag[wide_dtype] - zero_tag = DataTypeTag[wide_dtype] - - do_shuffle = False - value_shuffle_str = "" - if narrow_dtype_bits == 4 and wide_dtype_bits == 16: - value_shuffle_str = "cute::Layout, cute::Stride>" - do_shuffle = True - if narrow_dtype_bits == 8 and wide_dtype_bits == 16: - value_shuffle_str = "cute::Layout, cute::Stride>" - do_shuffle = True - do_shuffle = operation.mixed_input_shuffle and do_shuffle - - if do_shuffle: - if is_A_dtype_narrow: - stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" - layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" - else: - stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" - layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" - # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and - # layout_{a, b}_str are to prevent errors in Windows platform unity build - mixed_dtype_prepare_code = f""" -using {operation_name_str}_StrideNarrow = {stride_narrow_str}; -using {operation_name_str}_ValueShuffle = {value_shuffle_str}; -static constexpr int {operation_name_str}_NumShuffleAtoms = 1; -using {operation_name_str}_MmaAtomShape = cute::Layout>>; -using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>()); -using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout, {operation_name_str}_StrideNarrow>{{}})); - """ - - mixed_input_modes_to_element = { - MixedInputMode.ConvertOnly: narrow_tag, - MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", - MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>" - } - narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag) - - if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): - narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" - - if is_A_dtype_narrow: - element_a = narrow_element - else: - element_b = narrow_element - - blockwise_prepare_code = "" - if is_blockwise(operation.gemm_kind): - sfm_vec_size = operation.ScaleFactorMVecSize - sfn_vec_size = operation.ScaleFactorNVecSize - sfk_vec_size = operation.ScaleFactorKVecSize - blockwise_prepare_code = f""" -using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>; -using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA()); -using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB()); - """ - - values = { - 'operation_name': operation_name_str, - 'operation_suffix': self.operation_suffix, - 'problem_shape': self.problem_shape(operation), - 'element_a': element_a, - 'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)), - 'element_b': element_b, - 'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]), - 'element_d': DataTypeTag[operation.D.element], - 'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]), - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class_main': OpcodeClassTag[opcode_class_main], - 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'tile_shape_m': str(tile_shape_m), - 'tile_shape_n': str(tile_shape_n), - 'tile_shape_k': str(tile_shape_k), - 'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int", - 'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int", - 'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int", - 'instruction_shape_m': str(instruction_shape[0]), - 'instruction_shape_n': str(instruction_shape[1]), - 'instruction_shape_k': str(instruction_shape[2]), - 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]), - 'epilogue_schedule' : str(epilogue_schedule_type), - 'epi_tile_mn' : epi_tile_mn, - 'epilogue_functor': epilogue_functor, - 'stages': stage_count_string, - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'align_c': str(alignment_c), - 'align_d': str(alignment_d), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]), - 'mixed_dtype_prepare_code': mixed_dtype_prepare_code, - 'blockwise_prepare_code' : blockwise_prepare_code - } - - return SubstituteTemplate(self.gemm_template, values) - -################################################################################################### - -# -class EmitGemmPlanarComplexInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [] - self.template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, - ${element_c}, cutlass::layout::RowMajor, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombinationPlanarComplex< - ${element_c}, - ${alignment_c}, - ${element_accumulator}, - ${element_epilogue} - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - ${stages}, - ${math_operator} - >::GemmKernel; - - struct ${operation_name} : - public Operation_${operation_name} { }; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major - transposed_layout_A = TransposedLayout[operation.A.layout] - transposed_layout_B = TransposedLayout[operation.B.layout] - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.B.element], - 'layout_a': LayoutTag[transposed_layout_B], - 'transform_a': ComplexTransformTag[operation.B.complex_transform], - 'alignment_a': str(operation.B.alignment), - 'element_b': DataTypeTag[operation.A.element], - 'layout_b': LayoutTag[transposed_layout_A], - 'transform_b': ComplexTransformTag[operation.A.complex_transform], - 'alignment_b': str(operation.A.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'alignment_c': str(operation.C.alignment), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages), - 'math_operator': 'cutlass::arch::OpMultiplyAdd' - } - - return SubstituteTemplate(self.template, values) - -################################################################################################### - -# -class EmitGemmPlanarComplexArrayInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [] - self.template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, - ${element_c}, cutlass::layout::RowMajor, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombinationPlanarComplex< - ${element_c}, - ${alignment_c}, - ${element_accumulator}, - ${element_epilogue} - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - ${stages}, - ${math_operator} - >::GemmArrayKernel; - - struct ${operation_name} : public Operation_${operation_name} { }; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - - # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major - transposed_layout_A = TransposedLayout[operation.A.layout] - transposed_layout_B = TransposedLayout[operation.B.layout] - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.B.element], - 'layout_a': LayoutTag[transposed_layout_B], - 'transform_a': ComplexTransformTag[operation.B.complex_transform], - 'alignment_a': str(operation.B.alignment), - 'element_b': DataTypeTag[operation.A.element], - 'layout_b': LayoutTag[transposed_layout_A], - 'transform_b': ComplexTransformTag[operation.A.complex_transform], - 'alignment_b': str(operation.A.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'alignment_c': str(operation.C.alignment), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages), - 'math_operator': 'cutlass::arch::OpMultiplyAdd' - } - - return SubstituteTemplate(self.template, values) - -################################################################################################### - -# -class EmitGemmGroupedInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix = ''): - self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/numeric_types.h", - "cutlass/arch/arch.h", - "cutlass/arch/mma.h", - "cutlass/layout/matrix.h", - "cutlass/gemm/device/gemm.h", - "cutlass/gemm/kernel/gemm_grouped.h", - "cutlass/gemm/kernel/default_gemm_grouped.h", - "cutlass/gemm/device/gemm_grouped.h" - ] - self.builtin_epilogue_functor_template = \ -"""${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >""" - - self.gemm_template = """ -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmGrouped< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${scheduler_mode}, - ${math_operation} ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmGrouped<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - warp_count = operation.tile_description.warp_count - - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - transpose_layouts = { - LayoutType.ColumnMajor: LayoutType.RowMajor, - LayoutType.RowMajor: LayoutType.ColumnMajor - } - - instance_layout_A, instance_layout_B, instance_layout_C = \ - (operation.A.layout, operation.B.layout, operation.C.layout) - # - - # Support built-in epilogue functors or user-defined functions - if isinstance(operation.epilogue_functor, enum.Enum): - - epilogue_vector_length = \ - min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] - - values = { - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - } - epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) - else: - epilogue_functor = self.epilogue_functor.emit_declaration() - # - - values = { - 'operation_name': operation.procedural_name(), - 'operation_suffix': self.operation_suffix, - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[instance_layout_B], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_functor': epilogue_functor, - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] - } - - return SubstituteTemplate(self.gemm_template, values) - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitGemmConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') - - self.instance_emitter = { - GemmKind.Gemm: EmitGemmInstance, - GemmKind.Sparse: EmitSparseGemmInstance, - GemmKind.Universal: EmitGemmUniversalInstance, - GemmKind.Universal3x: EmitGemmUniversal3xInstance, - GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance, - GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance, - GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, - GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, - GemmKind.Grouped: EmitGemmGroupedInstance, - GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance, - GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance, - GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance, - GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance, - } - - self.gemm_kind_wrappers = { - GemmKind.Gemm: 'GemmOperation', - GemmKind.Sparse: 'GemmSparseOperation', - GemmKind.Universal: 'GemmUniversalOperation', - GemmKind.Universal3x: 'GemmUniversal3xOperation', - GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation', - GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation', - GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', - GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', - GemmKind.Grouped: 'GemmGroupedOperation', - GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation', - GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation', - GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation', - GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation', - } - - self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" - - self.separator = """ -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.header_template = """ -/* - Generated by gemm_operation.py - Do not edit. -*/ -""" - - self.initialize_function_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_${configuration_name}(Manifest &manifest) { - -""" - self.epilogue_template = """ - -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def __enter__(self): - _LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__") - _LOGGER.debug("*** configuration_path (file to write): " + - str(self.configuration_path)) - - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(self.header_template) - self.configuration_file.write(self.separator) - - self.includes = collections.OrderedDict([ - ("cutlass/cutlass.h", None), - ("cutlass/library/library.h", None), - ("cutlass/library/manifest.h", None), - ("library_internal.h", None), - ("gemm_operation.h", None), - ("gemm_operation_3x.hpp", None), - ("grouped_gemm_operation_3x.hpp", None), - ("sparse_gemm_operation_3x.hpp", None), - ("block_scaled_gemm_operation_3x.hpp", None), - ("blockwise_gemm_operation_3x.hpp", None), - ("cutlass/arch/wmma.h", None), - ("cutlass/numeric_types.h", None) - ]) - self.instance_definitions = [] - self.instance_wrappers = [] - - self.operations = [] - return self - - def emit(self, operation): - _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") - _LOGGER.debug("*** operation.gemm_kind: " + str(operation.gemm_kind)) - - emitter = self.instance_emitter[operation.gemm_kind]() - - for incl in emitter.includes: - self.includes[incl] = None - - self.operations.append(operation) - - self.instance_definitions.append(emitter.emit(operation)) - - self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind], - 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", - 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" - })) - - def __exit__(self, exception_type, exception_value, traceback): - - # Write includes - for incl, _ in self.includes.items(): - include_statement = "#include \"%s\"\n" % incl - self.configuration_file.write(include_statement) - - self.configuration_file.write(self.separator) - - # Write instance definitions in top-level namespace - for instance_definition in self.instance_definitions: - self.configuration_file.write(instance_definition) - - # Add wrapper objects within initialize() function - self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { - 'configuration_name': self.configuration_name - })) - - for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) - - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - -################################################################################################### -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py deleted file mode 100644 index 063e8fb1caa6626e8ba099133fee4dd3dc115e40..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py +++ /dev/null @@ -1,10962 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for enumerating CUTLASS library kernels -""" - -import argparse -import enum -from itertools import chain, product -import logging -import os.path -import shutil -import sys -import copy -from typing import Any, Dict, Optional, Sequence, Tuple - -_LOGGER = logging.getLogger(__name__) - -def logging_prefix(indent_level: int = 0) -> str: - """String prefix for start of each debug log entry""" - prefix = '*** ' - indent = ' ' - return f"{prefix}{indent_level * indent}" - -def log_debug_line(line: str, indent_level: int = 0) -> None: - """Log one line of debug output""" - prefix = logging_prefix(indent_level) - _LOGGER.debug(prefix + line) - -# Certain usecases of cutlass_library nearly always prefer to run as scripts with -# relative imports, rather than via an installed Python package. An example of this -# is using CUTLASS's CMake system to generate a library of kernels to be profiled. -# To make it easy to use these use cases when an existing installation of cutlass_library -# exists, this global flag can be set to true (via command-line arguments) to ensure -# that package-based installations are not used. - -# Create a temporary argument parser to check only for the availability of the -# --disable-cutlass-package-imports argument, which controls whether package-based -# imports are disabled. -def _add_package_disablement_flag(argparser): - argparser.add_argument("--disable-cutlass-package-imports", action='store_true', required=False, - help="Disable use of cutlass_library from Python package") - -_parser = argparse.ArgumentParser() -_add_package_disablement_flag(_parser) -_args, _ = _parser.parse_known_args() - -# Add `CUTLASS_IGNORE_PACKAGE` to `builtins` so that it is visible for gating future -# imports without requiring importing another module. Ideally, we would just place this -# as a global variable in a module to that could be imported and checked (e.g., -# utils.CUTLASS_IGNORE_PACKAGE). However, this raises the issue of determining -# where this module should be sourced (from the cutlass_library package or from -# a relative import), which is the problem this variable is being used to solve in the -# first place. -import builtins -builtins.CUTLASS_IGNORE_PACKAGE = _args.disable_cutlass_package_imports - -try: - if CUTLASS_IGNORE_PACKAGE: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * - from cutlass_library.manifest import * - from cutlass_library.heuristics import * - from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist -except ImportError: - from library import * - from manifest import * - from heuristics import * - from emit_kernel_listing import emit_gemm_kernel_testlist -################################################################################################### - -# -def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): - - # by default, use the latest CUDA Toolkit version - cuda_version = [11, 0, 132] - - # Update cuda_version based on parsed string - if semantic_ver_string != '': - for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): - if i < len(cuda_version): - cuda_version[i] = x - else: - cuda_version.append(x) - return cuda_version >= [major, minor, patch] - -# From cuda 13.0, Thor SM is renumbered from 101 to 110 -def ThorSMRenumbering(cuda_version): - return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101 - -################################################################################################### -################################################################################################### - -# -def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8): - ''' Helper to compute the maximum alignment of the epilogue ''' - - def product(X, identity = 1): - result = identity - for item in X: - result *= item - return result - - elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps - return min(max_alignment, elements_per_thread) - -def DefaultSwizzlingFunctor(): - return SwizzlingFunctor.Identity8 - # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` - -# -def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ - alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = DefaultSwizzlingFunctor()): - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] - - element_a, element_b, element_c, element_epilogue = data_type - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - # If alignment is a tuple or a list, then we have different alignments for A and B - alignment_a = alignment if isinstance(alignment, int) else alignment[0] - alignment_b = alignment if isinstance(alignment, int) else alignment[1] - alignment_c = min(8, alignment_a) if isinstance(alignment, int) else alignment[2] - - A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0]) - B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1]) - C = TensorDescription(element_c, layout[2], alignment_c) - - new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts -def CreateGemmUniversal3xOperator( - manifest, layouts, tile_descriptions, data_types, - schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], - complex_transforms=None, - epilogue_functor=EpilogueFunctor.LinearCombination, - swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Default], - gemm_kind=GemmKind.Universal3x): - - if type(data_types) is dict: - data_types = [data_types] - - for s in schedules: - assert(len(s) == 2) - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - if len(tile_descriptions) == 0: - return operations - tile_descriptions = [tile_descriptions[0]] - - combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) - for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: - kernel_schedule, epilogue_schedule = schedules - A = TensorDescription( - data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) - B = TensorDescription( - data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) - - C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) - D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) - - gemm_op_extra_args = {} - element_compute = data_type.get("epi_type", data_type["acc_type"]) - - if "sf_type" in data_type: - gemm_op_extra_args["ScaleFactorA"] = data_type["sf_type"] - gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"] - gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]), - "vector_size" : data_type["sfd_type"]["vector_size"]} - assert is_block_scaled(gemm_kind) - - if tile_description.explicit_vector_sizes != None: - assert len(tile_description.explicit_vector_sizes) == 3 - gemm_op_extra_args["ScaleFactorMVecSize"] = tile_description.explicit_vector_sizes[0] - gemm_op_extra_args["ScaleFactorNVecSize"] = tile_description.explicit_vector_sizes[1] - gemm_op_extra_args["ScaleFactorKVecSize"] = tile_description.explicit_vector_sizes[2] - assert is_blockwise(gemm_kind) - else: - assert not is_blockwise(gemm_kind) - - A_dtype = data_type["a_type"] - B_dtype = data_type["b_type"] - A_dtype_bits = DataTypeSize[A_dtype] - B_dtype_bits = DataTypeSize[B_dtype] - is_A_dtype_narrow = A_dtype_bits < B_dtype_bits - if is_A_dtype_narrow: - narrow_dtype, wide_dtype = (A_dtype, B_dtype) - narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) - else: - narrow_dtype, wide_dtype = (B_dtype, A_dtype) - narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) - - mixed_input_modes = [None] - if narrow_dtype_bits != wide_dtype_bits: - if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): - mixed_input_modes = [MixedInputMode.ScaleOnly] - else: - mixed_input_modes = [MixedInputMode.ConvertOnly, MixedInputMode.ScaleOnly, MixedInputMode.ScaleWithZeroPoint] - - mixed_input_shuffle_options = [False] - if (mixed_input_modes[0] is not None) and (wide_dtype_bits == 16) and (narrow_dtype_bits == 4 or narrow_dtype_bits == 8): - mixed_input_shuffle_options = [False, True] - - for mixed_input_mode, mixed_input_shuffle in product(mixed_input_modes, mixed_input_shuffle_options): - operation = GemmOperation( - gemm_kind, tile_description.minimum_compute_capability, - tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, - kernel_schedule, epilogue_schedule, tile_scheduler, - mixed_input_mode=mixed_input_mode, mixed_input_shuffle=mixed_input_shuffle, **gemm_op_extra_args) - manifest.append(operation) - operations.append(operation) - - return operations - -# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts -def CreateSparseGemmUniversal3xOperator( - manifest, layouts, tile_descriptions, data_types, - schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], - complex_transforms=None, - epilogue_functor=EpilogueFunctor.LinearCombination, - swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Default]): - - if type(data_types) is dict: - data_types = [data_types] - - for s in schedules: - assert(len(s) == 2) - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0]] - - combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) - for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: - kernel_schedule, epilogue_schedule = schedules - A = TensorDescription( - data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) - B = TensorDescription( - data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) - - # Currently assume tensor C/D have same layout requirement. - C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) - D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) - - element_compute = data_type.get("epi_type", data_type["acc_type"]) - - operation = GemmOperation( - GemmKind.SparseUniversal3x, tile_description.minimum_compute_capability, - tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, - kernel_schedule, epilogue_schedule, tile_scheduler) - - manifest.append(operation) - operations.append(operation) - - return operations - -# -def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ - alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] - - element_a, element_b, element_c, element_epilogue = data_type - - gemm_kinds = [GemmKind.Sparse] - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) - B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) - C = TensorDescription(element_c, layout[2], alignment_c) - - new_operation = GemmOperation(GemmKind.Sparse, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# -def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type, \ - alignment_constraints, complex_transforms): - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] - - element_a, element_b, element_c, element_epilogue = data_type - - gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for gemm_kind in gemm_kinds: - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) - B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) - C = TensorDescription(element_c, layout[2], alignment_c) - - manifest.append(GemmOperation(gemm_kind, \ - tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue)) - return - -# -def CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, \ - alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] - - element_a, element_b, element_c, element_epilogue = data_type - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) - B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) - C = TensorDescription(element_c, layout[2], alignment_c) - - new_operation = GroupedGemmOperation(GemmKind.Grouped, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# -def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_type, \ - alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - - element_a, element_c, element_epilogue = data_type - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for fill_mode in fill_modes: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - - # SERK supported layouts (RowMajor, ColumnMajor) with no conjugation - complex_transform = ComplexTransform.none - - # HERK supported layouts (RowMajor + conj, ColumnMajor) - if blas_mode == BlasMode.hermitian and layout[0] == LayoutType.RowMajor: - complex_transform = ComplexTransform.conj - - alignment_c = 1 # Alignment only applies to A in SYRK - - A = TensorDescription(element_a, layout[0], alignment, complex_transform) - C = SymmetricTensorDescription(element_c, layout[1], fill_mode, alignment_c) - - # Rank-K update - new_operation = RankKOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) - - manifest.append(new_operation) - operations.append(new_operation) - - # Rank-2K update - new_operation = Rank2KOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# -def CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, data_type, \ - alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none),] - - element_a, element_b, element_c, element_epilogue = data_type - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for side_mode in side_modes: - for fill_mode in fill_modes: - for diag_type in diag_types: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TriangularTensorDescription(element_a, layout[0], side_mode, fill_mode, diag_type, - alignment, complex_transform) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - new_operation = TrmmOperation(TrmmKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# -def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, data_type, \ - alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - - element_a, element_b, element_c, element_epilogue = data_type - - operations = [] - - # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for side_mode in side_modes: - for fill_mode in fill_modes: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - - # SYMM supported layouts (RowMajor, ColumnMajor) with no conjugation - complex_transform = ComplexTransform.none - - alignment_a = 1 # No vectorized access for the triangular matrix - alignment_c = min(8, alignment) - - A = SymmetricTensorDescription(element_a, layout[0], fill_mode, alignment_a, complex_transform, side_mode) - # tensor A and B have same data type and layout - B = TensorDescription(element_b, layout[0], alignment) - C = TensorDescription(element_c, layout[1], alignment_c) - - # SYMM/HEMM update - new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) - - manifest.append(new_operation) - operations.append(new_operation) - - # SYMM/HEMM update - new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -########################################################################################################### -# ConvolutionOperator support variations -# ____________________________________________________________________ -# ConvolutionalOperator | Analytic | Optimized -# ____________________________________________________________________ -# | Fprop | (strided) | (strided) -# | Dgrad | (strided, unity*) | (strided, unity) -# | Wgrad | (strided) | (strided) -# ____________________________________________________________________ -# -# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low -########################################################################################################### -# Convolution for 2D operations -def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - - element_a, element_b, element_c, element_epilogue = data_type - - # one exceptional case - - # iterator algorithm (analytic and optimized) - iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] - - # by default, only generate the largest tile size, largest alignment, and optimized iterator - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - iterator_algorithms = [IteratorAlgorithm.Optimized] - - operations = [] - - for tile in tile_descriptions: - for alignment in alignment_constraints: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - swizzling_functor_ = swizzling_functor - - # - # Conv2d Fprop - # - if ConvKind.Fprop in conv_kinds: - - # Strided support for Analytic and Optimized Fprop - for iterator_algorithm in iterator_algorithms: - new_operations = [ - # None grouped kernel - Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_), - ] - - # Instance group conv kernel - if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \ - tile.minimum_compute_capability >= 80: - # SingleGroup kernel - new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) - - # Analytic iterator supports MultipleGroup mode - if iterator_algorithm == IteratorAlgorithm.Analytic: - new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup)) - - for new_operation in new_operations: - manifest.append(new_operation) - operations.append(new_operation) - - # - # Conv2d Dgrad - # - if ConvKind.Dgrad in conv_kinds: - - # Unity stride for Analytic and Optimized Dgrad - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) - - # Strided support for Analytic Dgrad - # strided dgrad uses a special threadblock swizzle - # note that SwizzlingFunctor.StridedDgradHorizontal might be - # better for problem sizes with large activation channel count - swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 - - if IteratorAlgorithm.Analytic in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) - - manifest.append(new_operation) - operations.append(new_operation) - - # Strided support for Optimized Dgrad - if IteratorAlgorithm.Optimized in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) - - manifest.append(new_operation) - operations.append(new_operation) - - # - # Conv2d Wgrad - # - if ConvKind.Wgrad in conv_kinds: - - # Strided support for Analytic and Optimized Wgrad - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# Convolution for 2D operations specialized for few channels -def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - - element_a, element_b, element_c, element_epilogue = data_type - - # one exceptional case - - # iterator algorithm (analytic and optimized) - iterator_algorithms = [IteratorAlgorithm.FixedChannels,] - - # by default, only generate the largest tile size, largest alignment, and optimized iterator - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - channel_counts = [channel_counts[0],] - - operations = [] - - - - for tile in tile_descriptions: - for channel_count in channel_counts: - - alignment_c = EpilogueAlignment(channel_count, tile) - - A = TensorDescription(element_a, layout[0], channel_count) - B = TensorDescription(element_b, layout[1], channel_count) - C = TensorDescription(element_c, layout[2], alignment_c) - - swizzling_functor_ = swizzling_functor - - # - # Conv2d Fprop - # - if ConvKind.Fprop in conv_kinds: - - # Strided support for Analytic and Optimized Fprop - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# Convolution for 2D operations specialized for few channels -def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - - element_a, element_b, element_c, element_epilogue = data_type - - # one exceptional case - - # iterator algorithm (analytic and optimized) - iterator_algorithms = [IteratorAlgorithm.FewChannels,] - - # by default, only generate the largest tile size, largest alignment, and optimized iterator - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - channel_counts = [channel_counts[0],] - - operations = [] - - for tile in tile_descriptions: - for channel_count in channel_counts: - - alignment_c = EpilogueAlignment(channel_count, tile) - - A = TensorDescription(element_a, layout[0], channel_count) - B = TensorDescription(element_b, layout[1], channel_count) - C = TensorDescription(element_c, layout[2], alignment_c) - - swizzling_functor_ = swizzling_functor - - # - # Conv2d Fprop - # - if ConvKind.Fprop in conv_kinds: - - # Strided support for Analytic and Optimized Fprop - for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# Convolution for 3D operations -def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): - - element_a, element_b, element_c, element_epilogue = data_type - - # one exceptional case - alignment_c = min(8, alignment) - - # iterator algorithm (analytic and optimized) - iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] - - # by default, only generate the largest tile size and optimized iterators - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - iterator_algorithms = [IteratorAlgorithm.Optimized] - - operations = [] - - # All tile sizes for Conv3dFprop and Conv3dWgrad - for tile in tile_descriptions: - A = TensorDescription(element_a, layout, alignment) - B = TensorDescription(element_b, layout, alignment) - C = TensorDescription(element_c, layout, alignment_c) - - # - # Conv3d Fprop - # - if ConvKind.Fprop in conv_kinds: - # Strided support for Analytic and Optimized Fprop - for iterator_algorithm in iterator_algorithms: - new_operation = Conv3dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided) - manifest.append(new_operation) - operations.append(new_operation) - # - # Conv3d Wgrad - # - if ConvKind.Wgrad in conv_kinds: - - # Strided support for Analytic and Optimized Wgrad - for iterator_algorithm in iterator_algorithms: - new_operation = Conv3dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) - manifest.append(new_operation) - operations.append(new_operation) - - # All tile sizes for Conv3dDgrad - for tile in tile_descriptions: - - A = TensorDescription(element_a, layout, alignment) - B = TensorDescription(element_b, layout, alignment) - C = TensorDescription(element_c, layout, alignment_c) - - # - # Conv3d Dgrad - # - if ConvKind.Dgrad in conv_kinds: - # Unity stride for Optimized Dgrad - new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - # Strided support for Analytic Dgrad - # Conv3dDgrad has a naive strided support which does not cut down redundant MMAs - new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -# Convolution for Depthwise 2d conv -def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - - element_a, element_b, element_c, element_epilogue = data_type - - # iterator algorithm (FixedStrideDilation, Optimized) - iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized] - - # by default, only generate the largest tile size, largest alignment, and optimized iterator - if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] - alignment_constraints = [alignment_constraints[0],] - - operations = [] - - for tile in tile_descriptions: - for alignment in alignment_constraints: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - swizzling_functor_ = swizzling_functor - - if ConvKind.Fprop in conv_kinds: - - # Strided support for Optimized and FixedStridedDilation Depthwise Conv - for iterator_algorithm in iterator_algorithms: - stride_support = StrideSupport.Strided - if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation: - if tile.stride == [-1, -1] or tile.dilation == [-1,-1]: - continue - stride_support = StrideSupport.Fixed - - if iterator_algorithm == IteratorAlgorithm.Optimized: - if tile.stride != [-1, -1] or tile.dilation != [-1,-1]: - continue - new_operation = Conv2dOperation(ConvKind.Fprop, - iterator_algorithm, - tile.minimum_compute_capability, - tile, - A, B, C, - element_epilogue, - stride_support, - epilogue_functor, - swizzling_functor_, - group_mode=GroupMode.Depthwise) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - -class ConvOperation3x: - """All parameters of a CUTLASS 3 convolution operation. - - Unlike CUTLASS 2 convolutions, CUTLASS 3 convolutions do not - distinguish between 2-D and 3-D convolutions by kernel class name. - Instead, for CUTLASS 3 convolutions, the tensor layouts encode - whether the convolution is 2-D or 3-D. Thus, this class deduces - the OperationKind (either Conv2d or Conv3d) from the layouts, - rather than taking it as a constructor parameter. - """ - def __init__(self, - conv_kind: ConvKind, - tile_description: TileDescription, - A: TensorDescription, - B: TensorDescription, - C: TensorDescription, - element_compute: Optional[DataType] = None, - D: Optional[TensorDescription] = None, - kernel_schedule: KernelScheduleType = KernelScheduleType.ScheduleAuto, - epilogue_schedule: EpilogueScheduleType = EpilogueScheduleType.ScheduleAuto, - tile_scheduler: TileSchedulerType = TileSchedulerType.Default, - log_indent_level: int = 1): - log_debug_line(f'ConvOperation3x::init: conv_kind: {conv_kind}', log_indent_level) - log_indent_level = log_indent_level + 1 - - self.conv_kind = conv_kind - self.tile_description = tile_description - self.A = A - self.B = B - self.C = C - self.element_compute = C.element if element_compute is None else element_compute - self.kernel_schedule = kernel_schedule - self.epilogue_schedule = epilogue_schedule - - self.arch = tile_description.minimum_compute_capability - self.tile_scheduler = tile_scheduler - if D == None: - self.D = C - else: - self.D = D - - self.is_3x = True - self.group_mode = GroupMode.NoneGroup # CUTLASS 3 convolutions currently aren't grouped - - operation_kind = None - for layout in (A.layout, B.layout, C.layout): - assert(isinstance(layout, LayoutType)) - new_operation_kind = convolution_tensor_layout_type_to_operation_kind(layout) - if operation_kind is None: - operation_kind = new_operation_kind - else: # CUTLASS 3 convolutions don't permit mixing 2-D and 3-D layouts. - assert(operation_kind == new_operation_kind) - assert(operation_kind is not None) - self.operation_kind = operation_kind - - def __str__(self): - return f"ConvOperation3x: operation_kind={self.operation_kind}, conv_kind={self.conv_kind}, tile_description={self.tile_description}" - - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 - ] - return self.tile_description.math_instruction.math_operation in complex_operators - - def is_mixed_input(self): - return self.A.element != self.B.element - - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - if self.is_complex(): - return get_complex_from_real(accum) - return accum - - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - MathOperation.and_popc: 'and', - } - - tensor_ops = [ - OpcodeClass.TensorOp, - OpcodeClass.WmmaTensorOp, - OpcodeClass.SparseTensorOp, - OpcodeClass.BlockScaledTensorOp, - ] - - is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops - - if is_tensor_op: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind]) - - def extended_name(self): - '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' - extended_name = "{core_name}_{element_a}{layout_a}_{element_b}{layout_b}_{element_acc}_{element_c}_{element_d}{layout_c}".format( - element_a = DataTypeNames[self.A.element], - layout_a = ShortLayoutTypeNames[self.A.layout], - element_b = DataTypeNames[self.B.element], - layout_b = ShortLayoutTypeNames[self.B.layout], - element_acc = DataTypeNames[self.accumulator_type()], - element_c = DataTypeNames[self.C.element], - layout_c = ShortLayoutTypeNames[self.C.layout], - element_d = DataTypeNames[self.D.element], - core_name = self.core_name()) - - return extended_name - - # Generates a short string representing underlying kernel schedule type - def kernel_schedule_name(self): - return KernelScheduleSuffixes[self.kernel_schedule] - - # Generates a short string representing underlying epilogue schedule type - def epilogue_schedule_name(self): - return EpilogueScheduleSuffixes[self.epilogue_schedule] - - # Generate a short string representing the operation class - def opcode_class_name(self): - return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - # Generates the full kernel function name - def configuration_name(self): - ''' The full function name indicates architecture, extended name, tile size, and layout. ''' - kernel_name_template = "cutlass3x_sm{ar}_{op}_{ex}{ct}{cs}_{l}_align{al}{t}{k}{e}" - return kernel_name_template.format( - ar = self.arch, - op = self.opcode_class_name(), - ex = self.extended_name(), - ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "", - cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), - l = self.tile_description.stages, - al = str(max(self.A.alignment, self.B.alignment)), - t = TileSchedulerSuffixes[self.tile_scheduler], - k = self.kernel_schedule_name(), - e = self.epilogue_schedule_name()) - - def procedural_name(self): - return self.configuration_name() - -def convolution_tensor_layout_type_to_operation_kind(layout: LayoutType) -> OperationKind: - if layout == LayoutType.TensorNHWC or layout == LayoutType.TensorKCSR: - return OperationKind.Conv2d - elif layout == LayoutType.TensorNDHWC or layout == LayoutType.TensorKCSRT: - return OperationKind.Conv3d - else: - raise RuntimeError(f'LayoutType {layout} does not have a corresponding OperationKind') - -def CreateConvOperator3x(manifest: Manifest, - dims_and_alignments: Sequence[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]], - tile_descriptions: Sequence[Sequence[TileDescription]], - data_types, - schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \ - [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)], - complex_transforms: Optional[Sequence[ComplexTransform]] = None, - tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default], - conv_kind: ConvKind = ConvKind.Fprop, - log_indent_level: int = 1): - """ - Create zero or more CUTLASS 3 two-dimensional convolution operators. - - Create a CUTLASS 3 two-dimensional convolution operator - for all feasible combinations of the input parameters. - Add the operators to the manifest. - - dims_and_alignments: 3-level list. Each outer list term is a list [A, B, C]. - Each inner list (A, B, or C) has the form [num_spatial_dimensions, alignment]. - Both are integers; the first is the number of spatial dimensions - (currently, only 2 or 3 are supported), and the second is the byte alignment. - We deduce the operation_kind (either OperationKind.Conv2d or OperationKind.Conv3d) - from num_spatial_dimensions. - - This function doesn't take layouts, unlike the GEMM functions. - CUTLASS 3 convolutions currently support three input layouts: - - * TensorNWC for 1-D convolutions, - * TensorNHWC for 2-D convolutions, and - * TensorNDHWC for 3-D convolutions. - - Output (C and D) layouts are the same as input layouts, - except for Wgrad convolutions, where the layouts are - - * TensorKCS for 1-D convolutions, - * TensorKCSR for 2-D convolutions, and - * TensorKCSRT for 3-D convolutions. - - The output layouts are completely constrained by the input layouts - and the convolution kind. - - tile_descriptions: 2-level list. - Outer level has one list per math instruction. - Inner level has one TileDescription for each cluster shape. - - data_types: Either a single data_type dictionary, or a list of them. - Keys: 'a_type', 'b_type', 'c_type', 'd_type', 'acc_type', 'epi_type' - - complex_transforms: Optional list of pairs. - First element of each pair is the complex transform for A, and - second element of each pair is the complex transform for B. - - schedule_pairs: [(kernel_schedule, epilogue_schedule), ...] - - conv_kind: Convolution kind (Fprop, Dgrad, or Wgrad). - """ - log_debug_line('CreateConvOperator3x', log_indent_level) - log_indent_level = log_indent_level + 1 - log_debug_line(f'conv_kind: {conv_kind}', log_indent_level) - - for triple in dims_and_alignments: - assert(isinstance(triple, tuple) or isinstance(triple, list)) - assert(len(triple) == 3) - - spatial_dimensionality = None # to be determined by loop below - - for entry in triple: # [A, B, C] - assert(len(entry) == 2) - [dim, alignment] = entry - assert(type(dim) is int) - assert(dim == 2 or dim == 3) - assert(type(alignment) is int) - assert(alignment > 0) - if spatial_dimensionality is None: - spatial_dimensionality = dim - else: - # A, B, and C need to have the same spatial dimensionality - assert(spatial_dimensionality == dim) - - def input_and_output_layouts(spatial_dim: int, kind: ConvKind) -> Tuple[LayoutType, LayoutType]: - if spatial_dim == 1: - input_layout = LayoutType.TensorNWC - if kind == ConvKind.Wgrad: - output_layout = LayoutType.TensorKCS - else: - output_layout = input_layout - elif spatial_dim == 2: - input_layout = LayoutType.TensorNHWC - if kind == ConvKind.Wgrad: - output_layout = LayoutType.TensorKCSR - else: - output_layout = input_layout - elif spatial_dim == 3: - input_layout = LayoutType.TensorNDHWC - if kind == ConvKind.Wgrad: - output_layout = LayoutType.TensorKCSRT - else: - output_layout = input_layout - else: - assert(False) - return (input_layout, output_layout) - - def dims_to_layouts(A_B_C: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]) -> \ - Tuple[Tuple[LayoutType, int], Tuple[LayoutType, int], Tuple[LayoutType, int]]: - [A, B, C] = A_B_C - [spatial_dim, alignment] = A - [input_layout, output_layout] = input_and_output_layouts(spatial_dim, conv_kind) - return ((input_layout, A[1]), - (input_layout, B[1]), - (output_layout, C[1])) - - # layouts: list of triples (A, B, C). - # Each of A, B, and C has the form [layout, alignment]. - layouts = [dims_to_layouts(A_B_C) for A_B_C in dims_and_alignments] - - if type(data_types) is dict: - data_types = [data_types] - - for s in schedule_pairs: - assert(len(s) == 2) - - if complex_transforms is None: - complex_transforms = [(ComplexTransform.none, ComplexTransform.none)] - - # product produces a one-pass generator, so the loop must call it anew each time. - def make_combinations(): - return product( - layouts, - tile_descriptions, - data_types, - complex_transforms, - schedule_pairs, - tile_schedulers - ) - - operations = [] - for layout_triple, tile_description, data_type, complex_transform_pair, schedule_pair, tile_scheduler in make_combinations(): - A_layout, A_alignment = layout_triple[0] - A_xform = complex_transform_pair[0] - B_layout, B_alignment = layout_triple[1] - B_xform = complex_transform_pair[1] - C_layout, C_alignment = layout_triple[2] - D_layout = C_layout - D_alignment = C_alignment - - A = TensorDescription(data_type["a_type"], A_layout, A_alignment, A_xform) - B = TensorDescription(data_type["b_type"], B_layout, B_alignment, B_xform) - C = TensorDescription(data_type["c_type"], C_layout, C_alignment) - D = TensorDescription(data_type["d_type"], D_layout, D_alignment) - element_compute = data_type.get("epi_type", data_type["acc_type"]) - kernel_schedule, epilogue_schedule = schedule_pair - - operation = ConvOperation3x(conv_kind=conv_kind, - tile_description=tile_description, - A=A, - B=B, - C=C, - element_compute=element_compute, - D=D, - kernel_schedule=kernel_schedule, - epilogue_schedule=epilogue_schedule, - tile_scheduler=tile_scheduler, - log_indent_level=log_indent_level) - log_debug_line(f'Created ConvOperation3x: {str(operation)}', log_indent_level) - manifest.append(operation) - operations.append(operation) - - return operations - -################################################################################################### -################################################################################################### - -# -def GenerateSM50_Simt(manifest, cuda_version): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - MathInstruction( \ - [1, 1, 1], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 50 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - if math_inst.element_a == DataType.f32: - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -# -def GenerateSM50_Simt_complex(manifest, cuda_version): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add_complex), - ] - - min_cc = 50 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - DataType.cf32, - DataType.cf32, - DataType.cf32, - DataType.cf32, - ] - - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -# -def GenerateSM50(manifest, cuda_version): - GenerateSM50_Simt(manifest, cuda_version) - GenerateSM50_Simt_complex(manifest, cuda_version) - -################################################################################################### -################################################################################################### - -# -def GenerateSM60_Simt(manifest, cuda_version): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 60 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) -# -def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 60 - max_cc = 1024 - - alignment_constraints = [8,] - - filter_3x3 = [3, 3] - filter_5x5 = [5, 5] - - # [stride_h, stride_w] - # [-1, -1] means all stride size. - strides = [[-1,-1], [1, 1], [2, 2]] - # [dilation_h, dilation_w] - # [-1, -1] means all dilation size. - dilations = [[-1,-1], [1, 1], [2, 2]] - - #groups per thread block - g16 = 16 - g32 = 32 - g64 = 64 - - #output shape per thread block - npq_1x4x4 = [1, 4, 4] - npq_1x8x8 = [1, 8, 8] - npq_1x10x10 = [1, 10, 10] - - tile_descriptions = [] - for math_inst in math_instructions: - for stride, dilation in product(strides, dilations): - tile_descriptions.extend([ - # filter3x3 ThreadBlock_output, filter, stage, warp - Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - - Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - - Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), - - # filter5x5 ThreadBlock_output, filter, stage, warp - Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - - Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_5x5, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - - Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc) - ]) - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -# -def GenerateSM60(manifest, cuda_version): - GenerateSM60_Simt(manifest, cuda_version) - GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version) - -################################################################################################### -################################################################################################### - -# -def GenerateSM61_Simt(manifest, cuda_version): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 4], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 61 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) -# - -# -def GenerateSM61(manifest, cuda_version): - GenerateSM61_Simt(manifest, cuda_version) - -################################################################################################### -################################################################################################### - -# -def GenerateSM70_TensorOp_884(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 4], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [8, 8, 4], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 70 - max_cc = 75 - - alignment_constraints = [8, 4, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) - -# -def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 4], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [8, 8, 4], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 70 - max_cc = 75 - - alignment_constraints = [8, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, complex_transforms) - - -# -def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version): - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 16, 16], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.WmmaTensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 16, 16], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.WmmaTensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 70 - max_cc = 1024 - - alignment_constraints = [8,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - -# -################################################################################################## -# - -def GenerateSM70(manifest, cuda_version): - GenerateSM70_TensorOp_884(manifest, cuda_version) - GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version) - - # To limit build size, WMMA GEMMs are disabled for now. - # - #GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version) - -################################################################################################### -################################################################################################### - -# -def GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst): - - min_cc = 75 - max_cc = 1024 - - tile_descriptions = [ - TileDescription([128, 64, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [2, 2, 2], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - - CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) - CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [1, 2, 4]) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) - CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [1, 2, 4]) - -# -def GenerateSM75_TensorOp_1688(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 75 - max_cc = 1024 - - alignment_constraints = [8, 4, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) - - # Separate generator for 'few channels' specializations - GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst) - -# - -# -def GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 75 - max_cc = 1024 - - alignment_constraints = [8, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, complex_transforms) - -# -def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 16], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [8, 8, 16], \ - DataType.u8, DataType.u8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 75 - max_cc = 90 - - alignment_constraints = [16,] - alignment_constraints_small_channels = [16, 8, 4] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), - - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - DataType.s32, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - DataType.f32, - ] - - operations = [] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - op.C.alignment = 16 - else: - op.C.alignment = 8 - -# - -# -def GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): - return - - layouts = [ - (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 16], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [8, 8, 16], \ - DataType.u8, DataType.u8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 75 - max_cc = 90 - - alignment_constraints = [16,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - DataType.f32, - ] - - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - op.C.alignment = 8 -# - -# -def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 32], \ - DataType.s4, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [8, 8, 32], \ - DataType.u4, DataType.u4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 75 - max_cc = 89 - - alignment_constraints = [32,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - DataType.s32, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - DataType.f32, - ] - - operations = [] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - op.C.alignment = 16 - elif op.tile_description.threadblock_shape[1] == 64: - op.C.alignment = 8 - else: - op.C.alignment = 8 - -# - -# -def GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): - return - - layouts = [ - (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 32], \ - DataType.s4, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [8, 8, 32], \ - DataType.u4, DataType.u4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 75 - max_cc = 89 - - alignment_constraints = [32,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - DataType.f32, - ] - - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - op.C.alignment = 16 -# - -# -def GenerateSM75_TensorOp_88128(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [8, 8, 128], \ - DataType.b1, DataType.b1, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.xor_popc), - ] - - min_cc = 75 - max_cc = { - MathOperation.xor_popc: 89, - MathOperation.and_popc: 90 - } - - alignment_constraints = [128,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([256, 64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - ] - - data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - -# - -# -def GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 16, 16], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.WmmaTensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 75 - max_cc = 1024 - - alignment_constraints = [16,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - DataType.f32, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - DataType.f32, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) -# - -# -def GenerateSM75_Simt_complex(manifest, cuda_version): - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add_complex), - ] - - min_cc = 75 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc) - ] - data_type = [ - DataType.cf32, - DataType.cf32, - DataType.cf32, - DataType.cf32 - ] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -def GenerateSM75(manifest, cuda_version): - GenerateSM75_TensorOp_1688(manifest, cuda_version) - GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version) - GenerateSM75_TensorOp_8816_TN(manifest, cuda_version) - GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version) - GenerateSM75_TensorOp_8832_TN(manifest, cuda_version) - GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version) - GenerateSM75_TensorOp_88128(manifest, cuda_version) - #GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version) - GenerateSM75_Simt_complex(manifest, cuda_version) - - -################################################################################################### -################################################################################################### - -# -def GenerateSM80_TensorOp_16816(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 16], \ - DataType.bf16, DataType.bf16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [8, 4, 2] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) - CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) - CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) - CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) - CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8) -# - -# -def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 32], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 32], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 32], \ - DataType.bf16, DataType.bf16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [8] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - -# - -# -def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 16], \ - DataType.bf16, DataType.bf16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [8, ] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([ 64, 128, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, complex_transforms) - -# -def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - # Upcast on Operand A - math_instructions = [ - MathInstruction( \ - [16, 8, 16], \ - DataType.s8, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.u8, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.s8, DataType.bf16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.u8, DataType.bf16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.s8, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.u8, DataType.f16, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - ] - - min_cc = 80 - max_cc = 1024 - - # For mixed-input alignment constraints are a list of lists, where the - # inner list contains the alignment constraints for operands/matrices - # [[alignA, alignB, alignC],..] - alignment_constraints = [[16, 8, 8],] - - for math_inst in math_instructions: - tile_descriptions = [ - # 128x128 - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - # 128x64 - TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - # 128x32 - TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - # 128x16 - TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_b != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_b, - math_inst.element_accumulator, - ] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - - for op in operations: - if (DataTypeSize[op.C.element] == 16) and \ - (op.tile_description.threadblock_shape[1] <= 32): - op.C.alignment = 4 - -# -def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.s8, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.u8, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.bf16, DataType.s8, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.bf16, DataType.u8, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.s8, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - MathInstruction( \ - [16, 8, 16], \ - DataType.f16, DataType.u8, DataType.f16, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - ] - - min_cc = 80 - max_cc = 1024 - - # For mixed-input alignment constraints are a list of lists, where the - # inner list contains the alignment constraints for operands/matrices - # [[alignA, alignB, alignC],..] - alignment_constraints = [[8, 16, 8],] - - for math_inst in math_instructions: - tile_descriptions = [ - # 128x128 - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - # 128x64 - TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - # 128x32 - TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 32], 9, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - # 128x16 - TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 16, 32], 9, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), - # 256x16 - TileDescription([256, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - - for op in operations: - if op.tile_description.threadblock_shape[1] <= 32: - op.C.alignment = 4 - -# -def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 32], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [16, 8, 32], \ - DataType.u8, DataType.u8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 80 - max_cc = 1024 - smem_usage = 164 - - alignment_constraints = [16,] - alignment_constraints_small_channels = [16, 8, 4] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] - data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - - operations = [] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - if op.tile_description.threadblock_shape[0] == 32: - op.C.alignment = 8 - else: - op.C.alignment = 16 - else: - op.C.alignment = 8 - -# - -def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - # Upcast on Operand A - math_instructions = [ - MathInstruction( \ - [16, 8, 32], \ - DataType.s4, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - ] - - min_cc = 80 - max_cc = 1024 - - # For mixed-input alignment constraints are a list of lists, where the - # inner list contains the alignment constraints for operands/matrices - # [[alignA, alignB, alignC],..] - alignment_constraints = [[32, 16, 4],] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - alignment_constraints = [[32, 16, 16],] - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_b, - DataType.f32 - ] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - if op.tile_description.threadblock_shape[0] == 32: - op.C.alignment = 8 - else: - op.C.alignment = 16 - else: - op.C.alignment = 8 -# - -# -def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - # Upcast on Operand B - math_instructions = [ - MathInstruction( \ - [16, 8, 32], \ - DataType.s8, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_mixed_input_upcast), - ] - - min_cc = 80 - max_cc = 1024 - - # For mixed-input alignment constraints are a list of lists, where the - # inner list contains the alignment constraints for operands/matrices - # [[alignA, alignB, alignC],..] - alignment_constraints = [[16, 32, 4],] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - alignment_constraints = [[16, 32, 16],] - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - DataType.f32, - ] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - if op.tile_description.threadblock_shape[0] == 32: - op.C.alignment = 8 - else: - op.C.alignment = 16 - else: - op.C.alignment = 8 -# - -# -def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 64], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [16,] - - tile_descriptions = [ - TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.s8, DataType.s8, DataType.s32, DataType.s32] - data_type_mixed = [DataType.s8, DataType.s8, DataType.s8, DataType.f32] - - CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - operations = [] - - operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - op.C.alignment = 16 - else: - op.C.alignment = 8 -# - -# -def GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 32], \ - DataType.s8, DataType.s8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [16, 8, 32], \ - DataType.u8, DataType.u8, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [16,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - op.C.alignment = 8 -# - -# -def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 64], \ - DataType.s4, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [16, 8, 64], \ - DataType.u4, DataType.u4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 80 - max_cc = 1024 - alignment_constraints = [32,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] - data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - operations = [] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - op.C.alignment = 16 - elif op.tile_description.threadblock_shape[1] == 64: - op.C.alignment = 8 - else: - op.C.alignment = 8 -# - -# -def GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 128], \ - DataType.s4, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate) - - min_cc = 80 - max_cc = 1024 - alignment_constraints = [32,] - - tile_descriptions = [ - TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.s4, DataType.s4, DataType.s32, DataType.s32] - data_type_mixed = [DataType.s4, DataType.s4, DataType.s4, DataType.f32] - - CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - operations = [] - - operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - if op.tile_description.threadblock_shape[1] > 128: - op.C.alignment = 16 - else: - op.C.alignment = 8 -# - -# -def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 64], \ - DataType.s4, DataType.s4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - MathInstruction( \ - [16, 8, 64], \ - DataType.u4, DataType.u4, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_saturate), - ] - - min_cc = 80 - max_cc = 1024 - alignment_constraints = [32,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - - operations = [] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - - conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) - - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - - for op in operations: - op.C.alignment = 16 -# - -# -def GenerateSM80_TensorOp_168256(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 256], \ - DataType.b1, DataType.b1, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.xor_popc), - MathInstruction( \ - [16, 8, 256], \ - DataType.b1, DataType.b1, DataType.s32, \ - OpcodeClass.TensorOp, \ - MathOperation.and_popc), - ] - - min_cc = 80 - max_cc = { - MathOperation.xor_popc: 89, - MathOperation.and_popc: 90 - } - - alignment_constraints = [128,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([256, 64, 512], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 256, 512], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 128, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 64, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 128, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 64, 512], 10, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([256, 64, 1024], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 256, 1024], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([128, 64, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 128, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), - ] - - data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - -# - -# -def GenerateSM80_TensorOp_1688(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [4, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) -# - -# -def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f16, DataType.f16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_f16), - MathInstruction( \ - [16, 8, 8], \ - DataType.bf16, DataType.bf16, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_bf16), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [4, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -# -def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [4, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -def GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_fast_f32) - - min_cc = 80 - max_cc = 1024 - - tile_descriptions = [ - TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 - ] - - alignment_constraints = [1,] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) - - -# -def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 16], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [4] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] - - CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) -# - -# -def GenerateSM80_TensorOp_1688_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 80 - max_cc = 1024 - - tile_descriptions = [ - TileDescription([128, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 - ] - - alignment_constraints = [1,] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1, 2, 4] # Alignment only applies to A in SYRK - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f32, DataType.f32, DataType.f32] - - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) -# - -# -def GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex), - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - DataType.cf32, DataType.cf32, DataType.cf32 - ] - - alignment_constraints = [1,] - - # SYRK - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HERK - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -# -def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1, 2, 4] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints) -# - -# -def GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex), - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 - ] - - alignment_constraints = [1,] - - complex_transforms = [ - ComplexTransform.none, ComplexTransform.conj, - ] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - # A and B have same layouts - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add), - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [ - 1, 2, 4 - ] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] - - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) -# - -# -def GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_instructions = [ - MathInstruction( \ - [16, 8, 8], \ - DataType.tf32, DataType.tf32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex), - MathInstruction( \ - [16, 8, 8], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_fast_f32), - ] - - min_cc = 80 - max_cc = 1024 - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 - ] - - alignment_constraints = [1,] - - # SYMM - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HEMM - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -# -def GenerateSM80_TensorOp_884(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) -# - -# -def GenerateSM80_TensorOp_884_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) - -# -def GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64] - - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) -# - -# -def GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64] - - # SYRK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HERK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) - -# - -# -def GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ComplexTransform.none,] - - # SYRK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HERK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -# -def GenerateSM80_TensorOp_884_trmm(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints) -# - -# -def GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - ComplexTransform.none, ComplexTransform.conj, - ] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - - -# -def GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - ComplexTransform.none, ComplexTransform.conj, - ] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM80_TensorOp_884_symm(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) -# - -# -def GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - # SYMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HEMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -# -def GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [8, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ComplexTransform.none,] - - # SYMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HEMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -################################################################################################### - -# -def GenerateSM80_Simt_f32(manifest, cuda_version): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 8], 5, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 8], 4, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - - -# -def GenerateSM80_Simt_f64(manifest, cuda_version): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 128, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) -# - - -################################################################################################## -# -def GenerateSM80_Simt_complex(manifest, cuda_version): - math_instructions = [ - MathInstruction( \ - [1, 1, 1], \ - DataType.f32, DataType.f32, DataType.f32, \ - OpcodeClass.Simt, \ - MathOperation.multiply_add_complex), - ] - - min_cc = 80 - max_cc = 1024 - - alignment_constraints = [1,] - - data_type = [ - DataType.cf32, - DataType.cf32, - DataType.cf32, - DataType.cf32 - ] - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - for math_inst in math_instructions: - - tile_descriptions = [ - TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) -# - -################################################################################################### - -# -def GenerateSM80(manifest, cuda_version): - GenerateSM80_TensorOp_16816(manifest, cuda_version) - GenerateSM80_SparseTensorOp_16832(manifest, cuda_version) - GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version) - GenerateSM80_TensorOp_1688(manifest, cuda_version) - GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version) - GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version) - GenerateSM80_TensorOp_1688_complex(manifest, cuda_version) - # 3xTF32 - GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version) - GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version) - GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version) - GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version) - GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version) - GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version) - GenerateSM80_TensorOp_1688_symm(manifest, cuda_version) - GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version) - GenerateSM80_TensorOp_884(manifest, cuda_version) - GenerateSM80_TensorOp_884_complex(manifest, cuda_version) - GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version) - GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version) - GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version) - GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version) - GenerateSM80_TensorOp_884_trmm(manifest, cuda_version) - GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version) - GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version) - GenerateSM80_TensorOp_884_symm(manifest, cuda_version) - GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version) - GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version) - GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version) - GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version) - GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) - GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version) - GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version) - GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) - GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) - GenerateSM80_TensorOp_16864_TN(manifest, cuda_version) - GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version) - GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version) - GenerateSM80_TensorOp_168256(manifest, cuda_version) - GenerateSM80_Simt_f32(manifest, cuda_version) - GenerateSM80_Simt_f64(manifest, cuda_version) - GenerateSM80_Simt_complex(manifest, cuda_version) - -################################################################################################### - -def GenerateSM89_TensorOp_16832_fp8(manifest, element_acc): - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) - ] - - math_instructions = [ - MathInstruction( - [16, 8, 32], - DataType.e4m3, DataType.e4m3, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 32], - DataType.e4m3, DataType.e5m2, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 32], - DataType.e5m2, DataType.e4m3, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 32], - DataType.e5m2, DataType.e5m2, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 32], - DataType.e4m3, DataType.e4m3, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - MathInstruction( - [16, 8, 32], - DataType.e4m3, DataType.e5m2, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - MathInstruction( - [16, 8, 32], - DataType.e5m2, DataType.e4m3, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - MathInstruction( - [16, 8, 32], - DataType.e5m2, DataType.e5m2, element_acc, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - ] - - min_cc = 89 - max_cc = 100 - alignment_constraints = [16,] - alignment_constraints_small_channels = [16, 8, 4] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_types = [ - [ - math_inst.element_a, - math_inst.element_b, - DataType.f32, - math_inst.element_accumulator - ], - [ - math_inst.element_a, - math_inst.element_b, - DataType.bf16, - math_inst.element_accumulator - ], - ] - - operations = [] - for data_type in data_types: - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, - alignment_constraints, None, EpilogueFunctor.LinearCombination) - - conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) - operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, - data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - - operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - if op.tile_description.threadblock_shape[0] == 32: - op.C.alignment = 8 - else: - op.C.alignment = 16 - else: - op.C.alignment = 8 - -def GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 4): - return - - GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f32) - -def GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f16) - -# -def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version): - - if ( - not CudaToolkitVersionSatisfies(cuda_version, 12, 4) - ): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) - ] - - math_instructions = [ - MathInstruction( - [16, 8, 64], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 64], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 64], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 64], - DataType.e5m2, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [16, 8, 64], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - MathInstruction( - [16, 8, 64], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - MathInstruction( - [16, 8, 64], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - MathInstruction( - [16, 8, 64], - DataType.e5m2, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_accum), - ] - - min_cc = 89 - max_cc = 89 - - alignment_constraints = [16,] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_types = [ - [ - math_inst.element_a, - math_inst.element_b, - DataType.f32, - math_inst.element_accumulator - ], - ] - - operations = [] - for data_type in data_types: - operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, - alignment_constraints, None, EpilogueFunctor.LinearCombination) - - for op in operations: - if op.tile_description.threadblock_shape[1] >= 128: - op.C.alignment = 16 - else: - op.C.alignment = 8 - -################################################################################################### - -# -def GenerateSM89(manifest, cuda_version): - GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version) - GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version) - GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version) - -################################################################################################### - - -try: - from .sm90_utils import ( - generate_fp16_bf16_math_instructions_sm90, - generate_tf32_math_instructions_sm90, - generate_int8_math_instructions_sm90, - generate_fp8_math_instructions_sm90, - generate_mixed_dtype_math_instructions_sm90, - make_sparse_math_instructions, - generate_tile_descriptions_sm90, - get_valid_schedules, - generate_data_types_from_math_instruction, - fix_alignments, - ) -except ImportError: - from sm90_utils import ( - generate_fp16_bf16_math_instructions_sm90, - generate_tf32_math_instructions_sm90, - generate_int8_math_instructions_sm90, - generate_fp8_math_instructions_sm90, - generate_mixed_dtype_math_instructions_sm90, - make_sparse_math_instructions, - generate_tile_descriptions_sm90, - get_valid_schedules, - generate_data_types_from_math_instruction, - fix_alignments, - ) - -def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], - ] - - math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_type_w_source = generate_data_types_from_math_instruction(math_inst) - data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) - data_types = [data_type_w_source, data_type_wo_source] - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_type_mixed_w_source = generate_data_types_from_math_instruction( - math_inst, - element_source=math_inst.element_a, - element_dest=math_inst.element_a - ) - data_type_mixed_wo_source = generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.void, - element_dest=math_inst.element_a - ) - data_types.append(data_type_mixed_w_source) - data_types.append(data_type_mixed_wo_source) - - for layout in layouts: - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - gemm_kind=gemm_kind, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) - is_aligned = False - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], - ] - - math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_type_w_source = generate_data_types_from_math_instruction(math_inst) - data_types = [data_type_w_source] - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_type_mixed_w_source = generate_data_types_from_math_instruction( - math_inst, - element_source=math_inst.element_a, - element_dest=math_inst.element_a - ) - data_types.append(data_type_mixed_w_source) - - for layout in layouts: - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - -def GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], - ] - - math_instructions = make_sparse_math_instructions(generate_fp16_bf16_math_instructions_sm90(instantiation_level)) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_type_w_source = generate_data_types_from_math_instruction(math_inst) - data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) - data_types = [data_type_w_source, data_type_wo_source] - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_type_mixed_w_source = generate_data_types_from_math_instruction( - math_inst, - element_source=math_inst.element_a, - element_dest=math_inst.element_a - ) - data_type_mixed_wo_source = generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.void, - element_dest=math_inst.element_a - ) - data_types.append(data_type_mixed_w_source) - data_types.append(data_type_mixed_wo_source) - - for layout in layouts: - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - ] - - math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - - for layout in layouts: - data_type_tf32 = generate_data_types_from_math_instruction(math_inst) - data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) - data_type_f32 = copy.deepcopy(data_type_tf32) - data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) - data_type_f32["a_type"] = DataType.f32 - data_type_f32["b_type"] = DataType.f32 - data_type_f32["epi_type"] = DataType.f32 - data_type_f32_wo_source["a_type"] = DataType.f32 - data_type_f32_wo_source["b_type"] = DataType.f32 - data_type_f32_wo_source["epi_type"] = DataType.f32 - data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] - - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) - is_aligned = False - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], - ] - - math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - - for layout in layouts: - # Inconsistency: TF32 does not stamp out void-C - data_type_tf32 = generate_data_types_from_math_instruction(math_inst) - data_type_f32 = copy.deepcopy(data_type_tf32) - data_type_f32["a_type"] = DataType.f32 - data_type_f32["b_type"] = DataType.f32 - data_type_f32["epi_type"] = DataType.f32 - for data_type in [data_type_tf32, data_type_f32]: - # Inconsistency: alignments aren't fixed in TF32 / alignx - # layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - ] - - math_instructions = make_sparse_math_instructions(generate_tf32_math_instructions_sm90(instantiation_level)) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - - for layout in layouts: - data_type_tf32 = generate_data_types_from_math_instruction(math_inst) - data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) - data_type_f32 = copy.deepcopy(data_type_tf32) - data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) - data_type_f32["a_type"] = DataType.f32 - data_type_f32["b_type"] = DataType.f32 - data_type_f32["epi_type"] = DataType.f32 - data_type_f32_wo_source["a_type"] = DataType.f32 - data_type_f32_wo_source["b_type"] = DataType.f32 - data_type_f32_wo_source["epi_type"] = DataType.f32 - data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] - - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], - ] - - math_instructions = generate_int8_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_type_w_source = generate_data_types_from_math_instruction(math_inst) - data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) - data_type_int8_output = generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.s8, - element_dest=math_inst.element_a, - element_epilogue=DataType.f32 - ) - data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] - - for layout in layouts: - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) - is_aligned = False - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - ] - - math_instructions = generate_int8_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_type_w_source = generate_data_types_from_math_instruction(math_inst) - data_type_int8_output = generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.s8, - element_dest=math_inst.element_a, - element_epilogue=DataType.f32 - ) - data_types = [data_type_w_source, data_type_int8_output] - - for layout in layouts: - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], - ] - - math_instructions = make_sparse_math_instructions(generate_int8_math_instructions_sm90(instantiation_level)) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - # s8.u8 and u8.s8 wgmma variants require PTX 8.4 - if math_inst.element_a != math_inst.element_b and not CudaToolkitVersionSatisfies(cuda_version, 12, 4): - continue - data_type_w_source = generate_data_types_from_math_instruction(math_inst) - data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) - data_type_int8_output = generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.s8, - element_dest=math_inst.element_a, - element_epilogue=DataType.f32 - ) - data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] - - for layout in layouts: - for data_type in data_types: - layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout - ] - - math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_types = [] - fp8_types = [DataType.e4m3, DataType.e5m2] - valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] - valid_types_for_c = copy.deepcopy(valid_types_for_d) - valid_types_for_c.append(DataType.void) - for c_type, d_type in product(valid_types_for_c, valid_types_for_d): - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=c_type, - element_dest=d_type, - ) - ) - else: - for d_type in valid_types_for_d: - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.void, - element_dest=d_type, - ) - ) - - for layout in layouts: - for data_type in data_types: - # Inconsistency: alignments aren't fixed in FP8 - # layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - gemm_kind=gemm_kind, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - -def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout - ] - - math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) - tile_descriptions_ = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - tile_descriptions = list() - - for desc in tile_descriptions_: - desc.explicit_vector_sizes = [1, desc.tile_shape[1], desc.tile_shape[2]] - tile_descriptions.append(copy.deepcopy(desc)) - desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] - tile_descriptions.append(copy.deepcopy(desc)) - desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] - tile_descriptions.append(copy.deepcopy(desc)) - desc.explicit_vector_sizes = [1, 1, desc.tile_shape[2]] - tile_descriptions.append(copy.deepcopy(desc)) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_types = [] - fp8_types = [DataType.e4m3, DataType.e5m2] - valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] - valid_types_for_c = copy.deepcopy(valid_types_for_d) - valid_types_for_c.append(DataType.void) - for c_type, d_type in product(valid_types_for_c, valid_types_for_d): - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=c_type, - element_dest=d_type, - ) - ) - else: - for d_type in valid_types_for_d: - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.void, - element_dest=d_type, - ) - ) - - for layout in layouts: - for data_type in data_types: - # Inconsistency: alignments aren't fixed in FP8 - # layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - gemm_kind=gemm_kind, - enable_fp8_fast_acc=False, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK], - gemm_kind=gemm_kind) - - - -def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) - is_aligned = False - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], # TN Layout - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], # TN Layout - ] - - math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_types = [generate_data_types_from_math_instruction(math_inst)] - fp8_types = [DataType.e4m3, DataType.e5m2] - valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] - valid_types_for_c = copy.deepcopy(valid_types_for_d) - valid_types_for_c.append(DataType.void) - for c_type, d_type in product(valid_types_for_c, valid_types_for_d): - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=c_type, - element_dest=d_type, - ) - ) - - for layout in layouts: - for data_type in data_types: - # Inconsistency: alignments aren't fixed in FP8 - # layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - -def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) - is_aligned = True - - # layouts for ABC, their alignments will be fixed later based on the data type - layouts = [ - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], - ] - - valid_types_for_a_b_acc = [ - (DataType.e4m3, DataType.f16, DataType.f32), - (DataType.e4m3, DataType.bf16, DataType.f32), - (DataType.e5m2, DataType.f16, DataType.f32), - (DataType.e5m2, DataType.bf16, DataType.f32), - (DataType.s8, DataType.f16, DataType.f32), - (DataType.s8, DataType.bf16, DataType.f32), - (DataType.u8, DataType.f16, DataType.f32), - (DataType.u8, DataType.bf16, DataType.f32), - (DataType.s4, DataType.f16, DataType.f32), - (DataType.s4, DataType.bf16, DataType.f32), - (DataType.s4, DataType.e4m3, DataType.f32), - (DataType.s4, DataType.e5m2, DataType.f32), - (DataType.u4, DataType.f16, DataType.f32), - (DataType.u4, DataType.bf16, DataType.f32), - (DataType.u2, DataType.f16, DataType.f32), - (DataType.u2, DataType.bf16, DataType.f32), - (DataType.s2, DataType.f16, DataType.f32), - (DataType.s2, DataType.bf16, DataType.f32), - ] - # Note: For sizeof(a_type) > sizeof(b_type), some generated kernels might crash due to a compiler bug. Disable it for now. - #swapped_valid_types_for_a_b_acc = [(b_type, a_type, acc_type) for a_type, b_type, acc_type in valid_types_for_a_b_acc] - #valid_types_for_a_b_acc = valid_types_for_a_b_acc + swapped_valid_types_for_a_b_acc - - math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc) - - valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] - valid_types_for_c = copy.deepcopy(valid_types_for_d) - - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_types = [] - - # Limit C/D types to avoid a giant number of instantiations. - # A typical use case for mixed dtype in DL is weight quantization (tensor A), - # therefore we can limit the output type to that of activation (tensor B). - valid_types_for_c = [math_inst.element_b] - valid_types_for_d = [math_inst.element_b] - - for c_type, d_type in product(valid_types_for_c, valid_types_for_d): - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=c_type, - element_dest=d_type, - ) - ) - - for layout in layouts: - for data_type in data_types: - # Fix alignments, DataTypeSize are in the unit of bits - alignment_bits = 128 - layout[0][1] = alignment_bits // DataTypeSize[data_type['a_type']] - layout[1][1] = alignment_bits // DataTypeSize[data_type['b_type']] - layout[2][1] = alignment_bits // DataTypeSize[data_type['c_type']] - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) - is_aligned = True - - # layouts for ABC and their alignments - layouts = [ - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout - ] - - math_instructions = make_sparse_math_instructions(generate_fp8_math_instructions_sm90(instantiation_level)) - tile_descriptions = generate_tile_descriptions_sm90( - math_instructions=math_instructions, - is_aligned=is_aligned, - level=instantiation_level) - - for tile_desc in tile_descriptions: - math_inst = tile_desc.math_instruction - data_types = [] - fp8_types = [DataType.e4m3, DataType.e5m2] - valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] - valid_types_for_c = copy.deepcopy(valid_types_for_d) - valid_types_for_c.append(DataType.void) - for c_type, d_type in product(valid_types_for_c, valid_types_for_d): - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=c_type, - element_dest=d_type, - ) - ) - else: - for d_type in valid_types_for_d: - data_types.append( - generate_data_types_from_math_instruction( - math_inst, - element_source=DataType.void, - element_dest=d_type, - ) - ) - - for layout in layouts: - for data_type in data_types: - # Inconsistency: alignments aren't fixed in FP8 - # layout = fix_alignments(data_type, layout, alignment_bits=128) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_desc, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_type, - instantiation_level=instantiation_level, - layout=layout, - ) - - if len(schedules): - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) - if len(stream_k_schedules): - assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) - CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]) - - -def GenerateSM90_TensorOp_1684(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = MathInstruction( - [16, 8, 4], - DataType.f64, DataType.f64, DataType.f64, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - - CreateGemmOperator(manifest, layouts, tile_descriptions, - data_type, alignment_constraints) - -# - -# -def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - (ComplexTransform.conj, ComplexTransform.none), - (ComplexTransform.none, ComplexTransform.conj), - (ComplexTransform.conj, ComplexTransform.conj) - ] - - CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64] - - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) -# - -# -def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64] - - # SYRK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HERK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) - -# - -# -def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ComplexTransform.none,] - - # SYRK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HERK computation - CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -# -def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints) -# - -# -def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - ComplexTransform.none, ComplexTransform.conj, - ] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - - -# -def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - diag_types = [ - DiagType.NonUnit, DiagType.Unit, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ - ComplexTransform.none, ComplexTransform.conj, - ] - - CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ - data_type, alignment_constraints, complex_transforms) -# - -# -def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] - - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) -# - -# -def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - # SYMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HEMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - -# -def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): - return - - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - - side_modes = [ - SideMode.Left, SideMode.Right, - ] - - fill_modes = [ - FillMode.Lower, FillMode.Upper, - ] - - math_inst = \ - MathInstruction( \ - [16, 8, 4], \ - DataType.f64, DataType.f64, DataType.f64, \ - OpcodeClass.TensorOp, \ - MathOperation.multiply_add_complex_gaussian) - - min_cc = 90 - max_cc = 90 - - alignment_constraints = [1,] - - tile_descriptions = [ - TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), - #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), - ] - - data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] - - complex_transforms = [ComplexTransform.none,] - - # SYMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.symmetric) - - # HEMM computation - CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ - data_type, alignment_constraints, BlasMode.hermitian) -# - - - -# Blackwell SM 100 generators - -try: - import cutlass_library.sm100_utils - from cutlass_library.sm100_utils import ( - generate_tf32_math_instructions_sm100, - generate_16b_math_instructions_sm100, - generate_f8f6f4_math_instructions_sm100, - generate_mxf8f6f4_math_instructions_sm100, - generate_mxf4nvf4_math_instructions_sm100, - generate_fp8_math_instructions_sm100, - generate_cluster_shapes_sm100, - get_pruning_level_from_global_level - ) -except ImportError: - import sm100_utils - from sm100_utils import ( - generate_tf32_math_instructions_sm100, - generate_16b_math_instructions_sm100, - generate_f8f6f4_math_instructions_sm100, - generate_mxf8f6f4_math_instructions_sm100, - generate_mxf4nvf4_math_instructions_sm100, - generate_fp8_math_instructions_sm100, - generate_cluster_shapes_sm100, - get_pruning_level_from_global_level - ) - -################################################################################################### - -def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: - if DataTypeSize[data_type] < 8 and is_f8f6f4: - return int(128) - return int(16 * 8 / DataTypeSize[data_type]) - -sm100_cluster_shape_1sm = [ - [4,4,1] - , DynamicClusterShape -] - -sm100_cluster_shape_2sm = [ - # cluster_m % 2 == 0 for 2sm - [4,4,1] - , DynamicClusterShape -] - -def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], - ] - - data_types = [ - { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) - - if thor_sm in manifest.compute_capabilities_baseline : - if [4,4,1] in cluster_shapes_1sm : - cluster_shapes_1sm.remove([4,4,1]) - if [4,4,1] in cluster_shapes_2sm : - cluster_shapes_2sm.remove([4,4,1]) - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - -def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) - - # layouts for ABC and their alignments. C alignment will be set later based on output type - layouts = [ - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) - - min_cc = 100 - max_cc = thor_sm - grouped = is_grouped(gemm_kind) - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) - - if thor_sm in manifest.compute_capabilities_baseline : - if [4,4,1] in cluster_shapes_1sm : - cluster_shapes_1sm.remove([4,4,1]) - if [4,4,1] in cluster_shapes_2sm : - cluster_shapes_2sm.remove([4,4,1]) - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[kernel_schedule, epi_schedule]], - tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[kernel_schedule, epi_schedule]], - tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - if grouped: - epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm - elif math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - -def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=591 , default_level=591 , exhaustive_level=9999) - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - grouped = is_grouped(gemm_kind) - - math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) - - if thor_sm in manifest.compute_capabilities_baseline : - if [4,4,1] in cluster_shapes_1sm : - cluster_shapes_1sm.remove([4,4,1]) - if [4,4,1] in cluster_shapes_2sm : - cluster_shapes_2sm.remove([4,4,1]) - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue - kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped) - epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, epi_schedule]], - tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - - # 2xSM MMA kernels - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue - - if grouped: - epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm - elif math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - -def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=593, default_level=593, exhaustive_level=9999) - - grouped = is_grouped(gemm_kind) - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - ] - - min_cc = 100 - max_cc = 100 - epi_type = DataType.f32 - - pruning_level = get_pruning_level_from_global_level(instantiation_level) - - math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_compile_time_dtype=grouped or pruning_level >= 1, enable_runtime_dtype=not grouped) - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) - - tile_schedulers = [ - TileSchedulerType.Default, - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, - [math_inst.instruction_shape[0], math_inst.instruction_shape[1], - math_inst.instruction_shape[2] * 4])) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, - [1, math_inst.instruction_shape[1], - math_inst.instruction_shape[2] * 4])) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, - [math_inst.instruction_shape[0], 1, - math_inst.instruction_shape[2] * 4])) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue - - is_runtime_datatype_a = is_runtime_datatype(data_type["a_type"]) - is_runtime_datatype_b = is_runtime_datatype(data_type["d_type"]) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) - epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) - epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped) - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]], - tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) - -def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): - - # SM100 MMA with mixed F4/F6/F8 inputs + without block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) - - grouped = is_grouped(gemm_kind) - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], - ] - - math_instructions_1sm, math_instructions_2sm = generate_f8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - - def change_priority_func(shapes_1sm, shapes_2sm): - shapes_1sm[(1,2,1)] = 6 - shapes_1sm[(1,4,1)] = 6 - shapes_2sm[(2,2,1)] = 6 - shapes_2sm[(2,4,1)] = 6 - shapes_2sm[(4,2,1)] = 6 - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - - if thor_sm in manifest.compute_capabilities_baseline : - if [4,4,1] in cluster_shapes_1sm : - cluster_shapes_1sm.remove([4,4,1]) - if [4,4,1] in cluster_shapes_2sm : - cluster_shapes_2sm.remove([4,4,1]) - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - kernel_data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - } - ] - - for kernel_data_type in kernel_data_types: - # Filter out some kernel - if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ - ( kernel_data_type["d_type"] == DataType.e5m2 ): - continue - - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - kernel_data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - } - ] - - for kernel_data_type in kernel_data_types: - # Filter some kernel - if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ - ( kernel_data_type["d_type"] == DataType.e5m2 ): - continue - - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - if math_inst.instruction_shape[0] == 128: - CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers) - else: - CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) - -def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): - - # SM100 MMA with mixed F4/F6/F8 inputs + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) - - grouped = is_grouped(gemm_kind) - - layouts = [ - [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], - ] - - math_instructions_1sm, math_instructions_2sm = generate_mxf8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - - def change_priority_func(shapes_1sm, shapes_2sm): - shapes_1sm[(1,2,1)] = 6 - shapes_1sm[(1,4,1)] = 6 - shapes_2sm[(2,2,1)] = 6 - shapes_2sm[(2,4,1)] = 6 - shapes_2sm[(4,2,1)] = 6 - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) - - ab_types = [ - DataType.f4, DataType.f6, - DataType.e2m1, - DataType.e2m3, - DataType.e3m2, - DataType.e5m2, - DataType.e4m3, - ] - - acc_types = [ DataType.f32 ] - - def tile_schedulers(sfdtype): - # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, - # the epilogue is the traditional linear combination, for which we already have tests with stream-K. - if sfdtype["type"] == DataType.void or grouped: - return [TileSchedulerType.Default] - else: - return [TileSchedulerType.Default, TileSchedulerType.StreamK] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - - if thor_sm in manifest.compute_capabilities_baseline : - if [4,4,1] in cluster_shapes_1sm : - cluster_shapes_1sm.remove([4,4,1]) - if [4,4,1] in cluster_shapes_2sm : - cluster_shapes_2sm.remove([4,4,1]) - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e3m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type in data_types: - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - - for math_inst in math_instructions_2sm: - assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e3m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - ] - - # Set alignment d based on Destination format. - for data_type in data_types: - for layout in layouts: - # alignment for a - layout[0][1] = get_tma_alignment_elt(data_type["a_type"]) - # alignment for b - layout[1][1] = get_tma_alignment_elt(data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(data_type["d_type"]) - for tile in tile_descriptions: - math_inst = tile.math_instruction - # Filter some kernels that does not meet the alignment requirements. - if layout[0][0] == LayoutType.ColumnMajor: - if math_inst.instruction_shape[0] // 2 % layout[0][1] != 0: - continue - else: - if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[0][1] != 0: - continue - - if layout[1][0] == LayoutType.RowMajor: - if math_inst.instruction_shape[1] // 2 % layout[1][1] != 0: - continue - else: - if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0: - continue - - if grouped: - CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], - [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - elif math_inst.instruction_shape[0] == 128: - CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], - [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - else: - CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], - [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - - - -def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): - # SM100 MMA with F4 + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) - - grouped = is_grouped(gemm_kind) - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], - ] - - math_instructions_1sm, math_instructions_2sm = generate_mxf4nvf4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - - def change_priority_func(shapes_1sm, shapes_2sm): - shapes_1sm[(1,2,1)] = 6 - shapes_1sm[(1,4,1)] = 6 - shapes_2sm[(2,2,1)] = 6 - shapes_2sm[(2,4,1)] = 6 - shapes_2sm[(4,2,1)] = 6 - - cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func=change_priority_func) - - acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions - - def tile_schedulers(sfdtype): - # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, - # the epilogue is the traditional linear combination, for which we already have tests with stream-K. - if sfdtype["type"] == DataType.void or grouped: - return [TileSchedulerType.Default] - else: - return [TileSchedulerType.Default, TileSchedulerType.StreamK] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - - if thor_sm in manifest.compute_capabilities_baseline : - if [4,4,1] in cluster_shapes_1sm : - cluster_shapes_1sm.remove([4,4,1]) - if [4,4,1] in cluster_shapes_2sm : - cluster_shapes_2sm.remove([4,4,1]) - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - assert math_inst.instruction_shape[2] * 4 == 256 - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for layout in layouts: - for data_type in data_types: - if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): - data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. - if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): - continue - - # E2M1 x E2M1, vector size 32, E8 - # E2M1 x E2M1, vector size 16, UE4M3 - isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) - epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) - nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) - fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) - - nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] - fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules - , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind - ) - if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules - , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind - ) - - for math_inst in math_instructions_2sm: - assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for layout in layouts: - for data_type in data_types: - if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): - data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. - if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): - continue - - # E2M1 x E2M1, vector size 32, E8 - isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - - epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm - epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) - nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) - fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) - - nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] - fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules - , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules - , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - -def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): - # SM100 MMA with F4 + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): - return - - grouped = is_grouped(gemm_kind) - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], - ] - - instruction_sizes_1sm = [ - [128, 128, 96], - ] - - instruction_sizes_2sm = [ - [256, 128, 96], - [256, 192, 96], - [256, 256, 96] - ] - - ab_types = [ - DataType.f4, - DataType.e2m1, - ] - - sf_types = [ - DataType.ue4m3, - DataType.ue8m0 - ] - - acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions - - def tile_schedulers(sfdtype): - # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, - # the epilogue is the traditional linear combination, for which we already have tests with stream-K. - if grouped: - return [TileSchedulerType.Default] - if sfdtype["type"] == DataType.void: - return [TileSchedulerType.Default] - else: - return [TileSchedulerType.Default, TileSchedulerType.StreamK] - - min_cc = 103 - max_cc = 103 - epi_type = DataType.f32 - - math_instructions_1sm = [] - - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - - for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, sf_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - sf_type) - ) - - math_instructions_2sm = [] - - for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, sf_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - sf_type) - ) - - cluster_shapes_1sm = [ - [1,1,1], - # [1,2,1], - [2,1,1], - # [1,4,1], - [4,4,1], - DynamicClusterShape - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - 768], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - for data_type in data_types: - # Set alignment d based on Destination format. - if DataTypeSize[data_type["c_type"]] == 0 : - layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] - else: - layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) - - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): - data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. - if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): - continue - # E2M1 x E2M1, vector size 32, E8 - isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - - epilogue_1sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) - - nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), epilogue_1sm_schedule] - nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] - nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] - fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), epilogue_1sm_schedule] - fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] - fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] - nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] - fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] - - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, - nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, - fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - [4,1,1], - # [4,2,1], - [4,4,1], - DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - for data_type in data_types: - # Set alignment d based on Destination format. - if DataTypeSize[data_type["c_type"]] == 0 : - layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] - else: - layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) - - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): - data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. - if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): - continue - # E2M1 x E2M1, vector size 32, E8 - isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - - epilogue_2sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) - - nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), epilogue_2sm_schedule] - nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] - nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] - fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), epilogue_2sm_schedule] - fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] - fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] - nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] - fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] - - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, - nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, - fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - - -def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - math_instructions_1sm = [ - MathInstruction( - [64, 128, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [128, 128, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 32], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }, - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - - -def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. - layouts = [ - # Alignment requirement will be over-write below - [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - kernel_data_types = [ - # void_c - { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - ] - - math_instructions_1sm = [ - MathInstruction( - [128, 128, 16], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - math_instructions_2sm = [ - MathInstruction( - [256, 128, 16], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 16], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_1sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_2sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], - tile_schedulers=tile_schedulers) - -def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. - layouts = [ - # Alignment requirement will be over-write below - [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - kernel_data_types = [ - # void_c - { - "a_type" : DataType.f16, - "b_type" : DataType.f16, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : DataType.f16, - "b_type" : DataType.f16, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - ] - - math_instructions_1sm = [ - MathInstruction( - [128, 128, 32], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - math_instructions_2sm = [ - MathInstruction( - [256, 128, 32], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 32], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_1sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_2sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], - tile_schedulers=tile_schedulers) - -def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. - layouts = [ - # Alignment requirement will be over-write below - [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - kernel_data_types = [ - # void_c - { - "a_type" : DataType.s8, - "b_type" : DataType.s8, - "c_type" : DataType.void, - "d_type" : DataType.s8, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : DataType.s8, - "b_type" : DataType.s8, - "c_type" : DataType.s8, - "d_type" : DataType.s8, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - ] - - math_instructions_1sm = [ - MathInstruction( - [128, 128, 64], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 64], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add)] - - math_instructions_2sm = [ - MathInstruction( - [256, 128, 64], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 64], - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_1sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_2sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], - tile_schedulers=tile_schedulers) - -def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. - layouts = [ - # Alignment requirement will be over-write below - [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - kernel_data_types = [ - # NOTE: a/b type in kernel will be overwrite below. - #* void_c - # f8_f8_f32_void_f16 - { - "a_type" : DataType.e4m3, - "b_type" : DataType.e4m3, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - #* non-void_c - # f8_f8_f32_f16_f8 - { - "a_type" : DataType.e4m3, - "b_type" : DataType.e4m3, - "c_type" : DataType.f16, - "d_type" : DataType.e4m3, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - }, - ] - - math_instructions_1sm = [ - # Runtime DType - MathInstruction( - [128, 128, 64], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 64], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - math_instructions_2sm = [ - # Runtime DType - MathInstruction( - [256, 128, 64], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 64], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_1sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update input AB type - kernel_data_type["a_type"] = math_inst.element_a - kernel_data_type["b_type"] = math_inst.element_b - - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_2sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - for kernel_data_type in kernel_data_types: - # Update input AB type - kernel_data_type["a_type"] = math_inst.element_a - kernel_data_type["b_type"] = math_inst.element_b - - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_copy = copy.deepcopy(layouts) - for layout in layouts_copy: - # alignment for a, 2 for sparsity - layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], - tile_schedulers=tile_schedulers) - -def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. - layouts = [ - # Alignment requirement will be over-write below - [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - tile_schedulers = [ - TileSchedulerType.Default, TileSchedulerType.StreamK - ] - - math_instructions_1sm = [ - # Runtime Dtype - MathInstruction( - [128, 128, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - - MathInstruction( - [128, 128, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - math_instructions_2sm = [ - # Runtime DType - MathInstruction( - [256, 128, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - - MathInstruction( - [256, 128, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_1sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - kernel_data_types = [ - # void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - ] - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_filtered = [] - for layout in layouts: - layout_filter = copy.deepcopy(layout) - # * A_K : Logical TileShape_K % 256 == 0 - # * A_M : TileShape_M % 128 == 0 - # * B_N : TileSize_N % 128 == 0 - # * B_K : TileSize_K % 128 == 0 - if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ - (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ - ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 128 == 0) or \ - (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): - # alignment for a, 2 for sparsity - layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - layouts_filtered.append(layout_filter) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_2sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - kernel_data_types = [ - # void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - ] - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_filtered = [] - for layout in layouts: - layout_filter = copy.deepcopy(layout) - # * A_K : Logical TileShape_K % 256 == 0 - # * A_M : TileShape_M % 128 == 0 - # * B_N : TileSize_N % 256 == 0 - # * B_K : TileSize_K % 128 == 0 - if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ - (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ - ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 256 == 0) or \ - (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): - # alignment for a, 2 for sparsity - layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - layouts_filtered.append(layout_filter) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], - tile_schedulers=tile_schedulers) - -# Conv Utility functions -def make_dims_and_alignments_triple(dim: int, bit_per_element_A: int, bit_per_element_B: int, bit_per_element_C: int): - bit_alignment_required_by_tma = 128 - return ((dim, bit_alignment_required_by_tma // bit_per_element_A), # A - (dim, bit_alignment_required_by_tma // bit_per_element_B), # B - (dim, bit_alignment_required_by_tma // bit_per_element_C)) # C - -def make_math_instruction_w_output(data_types: Tuple[DataType, DataType, DataType, DataType], - instruction_shape: Tuple[int, int, int]) -> (MathInstruction, DataType): - default_opcode = OpcodeClass.TensorOp - default_math_op = MathOperation.multiply_add - [A_data_type, B_data_type, Acc_data_type, Out_data_type] = data_types - return (MathInstruction( - instruction_shape, - A_data_type, B_data_type, Acc_data_type, - default_opcode, - default_math_op - ), Out_data_type) - -""" -Generate CUTLASS 3 convolution kernel(s) for SM100. - -This is meant to be called from GenerateSM100. -""" -def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, - log_indent_level: int = 0): - log_debug_line('GenerateSM100_TensorOp_16b_UMMA_conv3x', log_indent_level) - log_indent_level = log_indent_level + 1 - - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - thor_sm = ThorSMRenumbering(cuda_version) - - minimum_compute_capability = 100 - maximum_compute_capability = thor_sm - - spatial_dims = [2, 3] - - conv_kinds = [ - ConvKind.Fprop, - ConvKind.Dgrad, - ConvKind.Wgrad - ] - - stages = 0 # zero means "deduce the number of stages automatically" - - data_types_and_instruction_shapes_1sm = [ - # ((A,B,Acc,C/D), (InstM,InstN,InstK)) - ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (64, 128, 16)), - ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), - ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), - ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (64, 128, 16)), - ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), - ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (64, 128, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), - ] - math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), - data_types_and_instruction_shapes_1sm) - - cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] - - # tile_descriptions is a 2-level list. - # Each inner list is for each cluster shape. - for math_inst, output_type in math_instructions_w_output_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - cluster_multiplier = cluster_shape - # Unlike SM90, SM100 tile shape calculation includes cluster shape. - tile_shape = [ - math_inst.instruction_shape[0] * cluster_multiplier[0], - math_inst.instruction_shape[1] * cluster_multiplier[1], - math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] - ] - warp_count = [4, 1, 1] - tile_description = TileDescription( - tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, - cluster_shape) - tile_descriptions.append(tile_description) - - # It's typical to get the data types from the math instruction. - data_type = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : output_type, - "d_type" : output_type, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } - - dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] - - # Schedules - mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 - epilogue_schedule = EpilogueScheduleType.ScheduleAuto - schedule_pairs = [ - (mainloop_schedule, epilogue_schedule) - ] - - for conv_kind in conv_kinds: - CreateConvOperator3x(manifest, - dims_and_alignments = dims_and_alignments, - tile_descriptions = tile_descriptions, - data_types = data_type, - schedule_pairs = schedule_pairs, - conv_kind = conv_kind, - log_indent_level = log_indent_level) - - data_types_and_instruction_shapes_2sm = [ - # ((A,B,Acc,C/D), (InstM,InstN,InstK)) - ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), - ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), - ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (256, 256, 16)), - ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), - ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), - ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (256, 256, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (256, 256, 16)), - ] - math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), - data_types_and_instruction_shapes_2sm) - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] - - for math_inst, output_type in math_instructions_w_output_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - # Unlike SM90, SM100 tile shape calculation includes cluster shape. - tile_shape = [ - math_inst.instruction_shape[0] * cluster_multiplier[0], - math_inst.instruction_shape[1] * cluster_multiplier[1], - math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] - ] - warp_count = [4, 1, 1] - tile_description = TileDescription( - tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, - cluster_shape) - tile_descriptions.append(tile_description) - - # It's typical to get the data types from the math instruction. - data_type = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : output_type, - "d_type" : output_type, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } - - dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] - - # Schedules - mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 - epilogue_schedule = EpilogueScheduleType.ScheduleAuto - schedule_pairs = [ - (mainloop_schedule, epilogue_schedule) - ] - - for conv_kind in conv_kinds: - CreateConvOperator3x(manifest, - dims_and_alignments = dims_and_alignments, - tile_descriptions = tile_descriptions, - data_types = data_type, - schedule_pairs = schedule_pairs, - conv_kind = conv_kind, - log_indent_level = log_indent_level) - -def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, - log_indent_level: int = 0): - # Instantiate Fp8 Fprop kernels with e4m3 A/B, f32 Acc, e4m3/bf16/f16/f32 C/D - log_debug_line('GenerateSM100_TensorOp_fp8_UMMA_conv3x', log_indent_level) - log_indent_level = log_indent_level + 1 - - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - thor_sm = ThorSMRenumbering(cuda_version) - - minimum_compute_capability = 100 - maximum_compute_capability = thor_sm - - spatial_dims = [2, 3] - stages = 0 # zero means "deduce the number of stages automatically" - - data_types_and_instruction_shapes_1sm = [ - # ((A,B,Acc,C/D), (InstM,InstN,InstK)) - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (64, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (64, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (64, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (64, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), - ] - math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), - data_types_and_instruction_shapes_1sm) - - cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] - - for math_inst, output_type in math_instructions_w_output_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - cluster_multiplier = cluster_shape - # Unlike SM90, SM100 tile shape calculation includes cluster shape. - tile_shape = [ - math_inst.instruction_shape[0] * cluster_multiplier[0], - math_inst.instruction_shape[1] * cluster_multiplier[1], - math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] - ] - warp_count = [4, 1, 1] - tile_description = TileDescription( - tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, - cluster_shape) - tile_descriptions.append(tile_description) - - data_type = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : output_type, - "d_type" : output_type, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } - - dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] - - # Schedules - mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 - epilogue_schedule = EpilogueScheduleType.ScheduleAuto - schedule_pairs = [ - (mainloop_schedule, epilogue_schedule) - ] - - CreateConvOperator3x(manifest, - dims_and_alignments = dims_and_alignments, - tile_descriptions = tile_descriptions, - data_types = data_type, - schedule_pairs = schedule_pairs, - conv_kind = ConvKind.Fprop, - log_indent_level = log_indent_level) - - data_types_and_instruction_shapes_2sm = [ - # ((A,B,Acc,C/D), (InstM,InstN,InstK)) - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (256, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (256, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (256, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), - ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (256, 256, 32)), - ] - math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), - data_types_and_instruction_shapes_2sm) - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] - - for math_inst, output_type in math_instructions_w_output_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - # Unlike SM90, SM100 tile shape calculation includes cluster shape. - tile_shape = [ - math_inst.instruction_shape[0] * cluster_multiplier[0], - math_inst.instruction_shape[1] * cluster_multiplier[1], - math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] - ] - warp_count = [4, 1, 1] - tile_description = TileDescription( - tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, - cluster_shape) - tile_descriptions.append(tile_description) - - data_type = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : output_type, - "d_type" : output_type, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } - - dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] - - # Schedules - mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 - epilogue_schedule = EpilogueScheduleType.ScheduleAuto - schedule_pairs = [ - (mainloop_schedule, epilogue_schedule) - ] - - CreateConvOperator3x(manifest, - dims_and_alignments = dims_and_alignments, - tile_descriptions = tile_descriptions, - data_types = data_type, - schedule_pairs = schedule_pairs, - conv_kind = ConvKind.Fprop, - log_indent_level = log_indent_level) - -def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): - # SM120 MMA with mixed F4/F6/F8 inputs + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - layouts = [ - [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] - ] - - instruction_sizes = [ - [16, 8, 32] - ] - - tile_sizes = [ - [128, 128, 128] - ] - - cluster_shape = [1,1,1] - - ab_types = [ - DataType.e2m1, - DataType.e2m3, - DataType.e3m2, - DataType.e5m2, - DataType.e4m3, - ] - - acc_types = [ DataType.f32 ] - - def is_pingpong(kernel_schedule): - if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: - return True - else: - return False - - def tile_schedulers(sfdtype, kernel_schedule): - # Pingpong kernel schedule doesn't support stream-K. - # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, - # the epilogue is the traditional linear combination, for which we already have tests with stream-K - if is_pingpong(kernel_schedule): - return [TileSchedulerType.Default] - elif sfdtype["type"] == DataType.void: - return [TileSchedulerType.Default] - else: - return [TileSchedulerType.Default, TileSchedulerType.StreamK] - - min_cc = 120 - max_cc = 121 - - epi_type = DataType.f32 - - math_instructions = [] - - kernel_schedules = [ - KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, - KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 - ] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types): - math_instructions.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - for math_inst in math_instructions: - tile_descriptions = [] - for tile_size in tile_sizes: - tile_descriptions.append( - TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e3m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type, kernel_schedule in product(data_types, kernel_schedules): - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], - tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), - gemm_kind = GemmKind.BlockScaledUniversal3x - ) - -def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version): - # SM120 MMA with with F4 + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]] - ] - - instruction_sizes = [ - [16, 8, 64] - ] - - tile_sizes_cooperative = [ - [128, 128, 128], - [128, 128, 256], - [256, 128, 128] - ] - - tile_sizes_pingpong = [ - [128, 128, 128], - [128, 128, 256] - ] - - cluster_shape = [1,1,1] - - ab_types = [ - DataType.e2m1 - ] - - sf_types = [ - DataType.ue4m3, - DataType.ue8m0 - ] - - acc_types = [ DataType.f32 ] - - def is_pingpong(kernel_schedule): - if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \ - kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: - return True - else: - return False - - def is_nvf4(kernel_schedule): - if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \ - kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: - return True - else: - return False - - def tile_schedulers(sfdtype, kernel_schedule): - # Pingpong kernel schedule doesn't support stream-K. - # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, - # the epilogue is the traditional linear combination, for which we already have tests with stream-K - if is_pingpong(kernel_schedule): - return [TileSchedulerType.Default] - elif sfdtype["type"] == DataType.void: - return [TileSchedulerType.Default] - else: - return [TileSchedulerType.Default, TileSchedulerType.StreamK] - - min_cc = 120 - max_cc = 121 - - epi_type = DataType.f32 - - math_instructions = [] - - kernel_schedules = [ - KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, - KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, - KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, - KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 - ] - - for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types): - math_instructions.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - sf_type) - ) - - for math_inst in math_instructions: - for kernel_schedule in kernel_schedules: - tile_descriptions = [] - tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative - for tile_size in tile_sizes: - # nvf4 kernel only supports ue4m3 SF - # mxf4 kernel only supports ue8m0 SF - if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \ - (math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)): - tile_descriptions.append( - TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e2m1, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} - } - ] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type in data_types: - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], - tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), - gemm_kind = GemmKind.BlockScaledUniversal3x - ) - -def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - layouts = [ - [[LayoutType.RowMajor, 256], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] - ] - - tile_sizes = [ - [128, 128, 256] - ] - - cluster_shape = [1,1,1] - - warp_count = [4, 2, 1] - - acc_types = [ DataType.f32 ] - - instruction_sizes_mxf8f6f4 = [ - [16, 8, 64] - ] - - ab_types_mxf8f6f4 = [ - DataType.e2m1, - #DataType.e2m3, - DataType.e3m2, - #DataType.e5m2, - DataType.e4m3, - ] - - def tile_schedulers(kernel_schedule): - return [TileSchedulerType.Default] - - min_cc = 120 - max_cc = 121 - - kernel_schedules = [ - KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120, - ] - - math_instructions_mxf8f6f4 = [] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_mxf8f6f4, ab_types_mxf8f6f4, ab_types_mxf8f6f4, acc_types): - math_instructions_mxf8f6f4.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add) - ) - - # Create gemm operator for mxf8f6f4 - for math_inst in math_instructions_mxf8f6f4: - tile_descriptions_mxf8f6f4 = [] - for tile_size in tile_sizes: - tile_descriptions_mxf8f6f4.append( - TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f32, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.e5m2, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - } - ] - - for data_type, kernel_schedule in product(data_types, kernel_schedules): - # Set alignment d based on Destination format - for layout in layouts: - layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) - # Create gemm operator - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_mxf8f6f4, data_type, - [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], - tile_schedulers = tile_schedulers(kernel_schedule), - gemm_kind = GemmKind.SparseUniversal3x) - -def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - layouts = [ - [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]], - [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]] - ] - - cooperative_tile_sizes = [ - [128, 128, 128] - ] - pingpong_tile_sizes = [ - [64, 128, 128] - ] - - def get_tile_sizes(kernel_scheduler): - if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: - return pingpong_tile_sizes - return cooperative_tile_sizes - - def get_warp_count(kernel_scheduler): - if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: - return [2, 2, 1] - return [4, 2, 1] - - def get_sf_sizes(tile_size): - sf_sizes = [] - for vec_m in [1, 128]: - if tile_size[0] % vec_m > 0: - continue - for vec_n in [1, 128]: - if tile_size[1] % vec_m > 0: - continue - sf_sizes.append( - [vec_m, vec_n, 128] - ) - return sf_sizes - - cluster_shape = [1,1,1] - - acc_types = [ DataType.f32 ] - - instruction_sizes = [ - [16, 8, 32] - ] - - def tile_schedulers(kernel_schedule): - return [TileSchedulerType.Default] - - min_cc = 120 - max_cc = 121 - - kernel_schedulers = [ - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120, - KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120 - ] - - ab_types = [ - [DataType.e4m3, DataType.e4m3], - [DataType.e4m3, DataType.e5m2] - ] - - math_instructions = [] - - for instr_size, ab_type, acc_type in product(instruction_sizes, ab_types, acc_types): - a_type, b_type = ab_type - math_instructions.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - # Create gemm operator for mxf8f6f4 - for kernel_schedule in kernel_schedulers: - tile_sizes = get_tile_sizes(kernel_schedule) - warp_count = get_warp_count(kernel_schedule) - for math_inst in math_instructions: - tile_descriptions = [] - for tile_size in tile_sizes: - sf_sizes = get_sf_sizes(tile_size) - for sf_size in sf_sizes: - tile_descriptions.append( - TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape, - explicit_vector_sizes=sf_size) - ) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.bf16, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - }, - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.bf16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - } - ] - - for data_type in data_types: - # Set alignment d based on Destination format - for layout in layouts: - layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) - # Create gemm operator - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], - tile_schedulers = tile_schedulers(kernel_schedule), - gemm_kind = gemm_kind) - -def GenerateSM100(manifest, cuda_version): - arch_family_cc = ['100f', '101f', '103a'] - if CudaToolkitVersionSatisfies(cuda_version, 13, 0): - for old_cc, new_cc in [('101f', '110f')]: - arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc] - - # - # Dense Gemm - # - GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) - - GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) - - if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): - GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) - - GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) - # grouped GEMM - GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) - GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) - - # StreamK is included in regular generation - GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) - - # Blockwise kernels - GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) - GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) - - # - # Sparse Gemm - # - GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version) - GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version) - if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): - GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version) - GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version) - GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) - - # - # Block Scaled Gemm - # - GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) - GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) - GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) - GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) - - GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) - GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) - # - # Conv - # - GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version) - GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version) - - -def GenerateSM120(manifest, cuda_version): - # StreamK is included in regular generation # - # - # Dense Block Scaled Gemm - # - GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) - GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) - - # - # Sparse Gemm - # - GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version) - GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) - GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) - -################################################################################################### - -def GenerateSM90_Conv3x(manifest, cuda_version, - log_indent_level: int = 0): - """ - Generate CUTLASS 3 convolution kernel(s) for SM90. - - This is meant to be called from GenerateSM90. - """ - log_debug_line('GenerateSM90_Conv3x', log_indent_level) - log_indent_level = log_indent_level + 1 - - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - minimum_compute_capability = 90 - maximum_compute_capability = 90 - - spatial_dims = (2, 3) - - # MMA shapes (MMA_M, MMA_N, MMA_K): - # - # Different hardware MMA instructions may have different MMA shapes. - # This function may generate kernels with different MMA shapes for - # different data types, either because the hardware only supports - # certain shapes for certain types, or for performance reasons - # (CUTLASS doesn't need to generate all valid kernels for the - # profiler library, just the best-performing ones). - # - # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) - # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, - # where 4, the "number of MMA instructions per tile," is determined - # through some combination of modeling and experiment. - # - # For performance on sm90, generally CUTLASS generates 64x128 - # instead of 128x64. - mma_64x64x16 = ( 64, 64, 16) - mma_64x64x8 = ( 64, 64, 8) - - num_mma_per_tile = 4 - - # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, - # but not included, because they tend not to perform as well. - cluster_shapes = ( - (2, 1, 1), - (1, 2, 1), - ) - - fp16 = DataType.f16 - bf16 = DataType.bf16 - fp32 = DataType.f32 - s8 = DataType.s8 - s32 = DataType.s32 - - # When generating kernels, the usual way is to specify 4 types, - # (A, B, Acc, C/D). Tests instead have 5 types, - # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), - # where ElementCompute is also called 'epi_type', - # and corresponds to the type of epilogue activations. - # This script maps tests' 5 types to 4 types - # by making ElementCompute the same as ElementOut. - - fp16_fp32_fp16_fp32 = { - 'a_type': fp16, # ElementAct(ivation) - 'b_type': fp16, # ElementF(i)lt(er) - 'c_type': fp32, # ElementAcc - 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) - 'acc_type': fp16, # ElementAcc - 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) - 'alignment_A': 8, # tma alignment elements of A - 'alignment_B': 8, # tma alignment elements of B - 'alignment_C': 4, # tma alignment elements of C - } - fp16_fp32_fp32_fp32 = { - 'a_type': fp16, - 'b_type': fp16, - 'c_type': fp32, - 'd_type': fp32, - 'acc_type': fp32, - 'epi_type': fp32, - 'alignment_A': 8, - 'alignment_B': 8, - 'alignment_C': 4, - } - fp32_fp32_fp32_fp32 = { - 'a_type': fp32, - 'b_type': fp32, - 'c_type': fp32, - 'd_type': fp32, - 'acc_type': fp32, - 'epi_type': fp32, - 'alignment_A': 4, - 'alignment_B': 4, - 'alignment_C': 4, - } - s8_s32_s32_s32 = { - 'a_type': s8, - 'b_type': s8, - 'c_type': s32, - 'd_type': s32, - 'acc_type': s32, - 'epi_type': s32, - 'alignment_A': 16, - 'alignment_B': 16, - 'alignment_C': 4, - } - - # Other NVIDIA libraries may have the habit of specifying data types like this. - bf16bf16_bf16f32_f32 = { - 'a_type': bf16, - 'b_type': bf16, - 'c_type': fp32, - 'd_type': fp32, - 'acc_type': fp32, - 'epi_type': fp32, - 'alignment_A': 8, - 'alignment_B': 8, - 'alignment_C': 4, - } - f16f16_f16f16_f16 = { - 'a_type': fp16, - 'b_type': fp16, - 'c_type': fp16, - 'd_type': fp16, - 'acc_type': fp16, - 'epi_type': fp16, - 'alignment_A': 8, - 'alignment_B': 8, - 'alignment_C': 8, - } - f16f16_f16f32_f32 = { - 'a_type': fp16, - 'b_type': fp16, - 'c_type': fp16, - 'd_type': fp16, - 'acc_type': fp32, - 'epi_type': fp32, - 'alignment_A': 8, - 'alignment_B': 8, - 'alignment_C': 8, - } - f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 - - i8i8_i8i32_f32 = { - 'a_type': s8, - 'b_type': s8, - 'c_type': s32, - 'd_type': s32, - 'acc_type': s32, - 'epi_type': s32, - 'alignment_A': 16, - 'alignment_B': 16, - 'alignment_C': 4, - } - - # Each element in the outermost iterable is one combination of - # - # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) - # - # for which to generate a kernel. spatial_dimension is the spatial - # dimension of the convolution: either 1, 2, or 3. byte_alignments - # is a triple of required minimum byte alignments for A, B, and C. - # - # Note that itertools functions produce a single-pass generator. - # The code doesn't need a multipass iterable, but if one did, one - # could call `tuple` or `list` on the generator. - # - # While this happens to use the same cluster sizes for each element, - # the code doesn't require that. Different convolution kinds, data - # types, or mma sizes might have different optimal cluster sizes. - combinations_of_parameters = chain( - # The following are all the kernels exercised in the unit tests. - # Please try to keep in sync with the unit tests. - product( - ( - ConvKind.Fprop, - ), - spatial_dims, - ( - fp16_fp32_fp16_fp32, - fp16_fp32_fp32_fp32, - s8_s32_s32_s32, - ), - ( - mma_64x64x16, - ), - cluster_shapes - ), - product( - ( - ConvKind.Fprop, - ), - spatial_dims, - ( - fp32_fp32_fp32_fp32, - ), - ( - mma_64x64x8, - ), - cluster_shapes - ), - product( - ( - ConvKind.Dgrad, - ConvKind.Wgrad - ), - spatial_dims, - ( - fp16_fp32_fp16_fp32, - fp16_fp32_fp32_fp32, - ), - ( - mma_64x64x16, - ), - cluster_shapes - ), - # Kernels not necessarily in the unit tests, but used elsewhere - # and thus useful to have generated for profiling. They may - # duplicate kernels above. All of them are 2-D. In general, - # CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the - # hardware permits 128 x 64. - ( - # Fprop - # - # bf16bf16_bf16f32_f32 - # - # cluster shape (2, 1, 1) - # - (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), - (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), - # - # f16f16_f16f16_f16 - # - # cluster shape (1, 1, 1) - # - (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), - # - # f16f16_f16f32_f32 - # - # cluster shape (2, 1, 1) - # - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 16), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 16), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), - # - # f32f32_tf32f32_f32 - # - # cluster shape (2, 1, 1) - # - (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 192, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 256, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 128, 8), (2, 1, 1)), - (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 96, 8), (2, 1, 1)), - # - # i8i8_i8i32_f32 - # - # cluster shape (2, 1, 1) - # - (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 16), (2, 1, 1)), - (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 32), (2, 1, 1)), - (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 16), (2, 1, 1)), - (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 32), (2, 1, 1)), - # - # Dgrad - # - # bf16bf16_bf16f32_f32 - # - # cluster shape (2, 1, 1) - # - (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), - (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), - (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), - (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), - # - # f16f16_f16f16_f16 - # - # cluster shape (1, 1, 1) - # - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), - # - # f16f16_f16f32_f32 - # - # cluster shape (2, 1, 1) - # - (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), - (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), - ), - ) - - # SM >= 90 kernels don't actually use warp_count, but the - # TileDescription class needs it. The 4 in the default - # warp_count has nothing to do with num_mma_per_tile. - warp_count = [4, 1, 1] - - stages = 0 # zero means "deduce the number of stages automatically" - - mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90 - epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized - schedule_pairs = ( - (mainloop_schedule, epilogue_schedule), - ) - tile_schedulers = ( - TileSchedulerType.Default, # -> void - ) - - def make_math_instruction(data_types: Dict[str, DataType], - mma_shape: Tuple[int, int, int]) -> MathInstruction: - default_opcode = OpcodeClass.TensorOp - default_math_op = MathOperation.multiply_add - return MathInstruction( - mma_shape, - data_types['a_type'], data_types['b_type'], data_types['c_type'], - default_opcode, - default_math_op - ) - - for (conv_kind, spatial_dim, data_types, mma_shape, cluster_shape) in combinations_of_parameters: - math_inst = make_math_instruction(data_types, mma_shape) - tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2]) - tile_description = TileDescription(tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, cluster_shape) - assert(isinstance(spatial_dim, int)) - dims_and_alignments = ( - ( - (spatial_dim, data_types['alignment_A']), - (spatial_dim, data_types['alignment_B']), - (spatial_dim, data_types['alignment_C']), - ), - ) - CreateConvOperator3x(manifest, - dims_and_alignments = dims_and_alignments, - tile_descriptions = [tile_description], - data_types = data_types, - schedule_pairs = schedule_pairs, - tile_schedulers = tile_schedulers, - conv_kind = conv_kind, - log_indent_level = log_indent_level) - -def GenerateSM90(manifest, cuda_version): - GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_1684(manifest, cuda_version) - GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) - GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) - GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) - GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) - GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) - GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version) - GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version) - GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version) - GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version) - GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version) - GenerateSM90_TensorOp_1684_symm(manifest, cuda_version) - GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version) - GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version) - GenerateSM90_Conv3x(manifest, cuda_version) - GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version) - GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version) - GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) - -################################################################################################### - -def numeric_log_level(log_level: str) -> int: - """ - Converts the string identifier of the log level - into the numeric identifier used in setting the log level. - - :param x: string representation of log level (e.g., 'INFO', 'DEBUG') - :type x: str - - :return: numeric representation of log level - :rtype: int - """ - numeric_level = getattr(logging, log_level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError(f'Invalid log level: {log_level}') - return numeric_level - -# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface -# to leverage the functionality in this file without running this script via a shell prompt. -def define_parser(): - parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") - parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)") - parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") - parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") - parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") - parser.add_argument("--architectures", default='53;60;61;70;75;80;90;100', help="Target compute architectures") - parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' + - 'Specifying this as \"all\" includes ALL the kernels, ' + - 'while not specifying this includes only the default set of kernels.') - parser.add_argument("--ignore-kernels", default='', help='Comma-delimited list of kernels ' + - 'to exclude from build. For backwards compatibility reasons, ' + - 'this option only takes effect if --kernels is set to a nonempty value.') - parser.add_argument("--exclude-kernels", default='', help='Comma-delimited list of kernels ' + - 'to exclude from build. In contrast to --ignore-kernels, ' + - 'this option always takes effect, ' + - 'whether or not --kernels is set to a nonempty value. ' + - 'It also can exclude kernels from the filter file ' + - '(see --kernel-filter-file option below).') - parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') - parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") - parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') - parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list') - parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler') - parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000']) - parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list') - parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py') - parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, - help='Specify the output log file containing all enabled kernels in this build') - parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") - parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures") - parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, - help='Logging level to be used by the generator script') - parser.add_argument('--instantiation-level', type=str, default="", required=False, help="Instantiation level for SM90 kernels. Set to `max` and make sure `--kernels` is not empty to generate all possible configurations.") - _add_package_disablement_flag(parser) - return parser - - -if __name__ == "__main__": - parser = define_parser() - args = parser.parse_args() - - # Set the logging level based on the user-provided `--log-level` command-line option - logging.basicConfig(level=args.log_level) - - manifest = Manifest(args) - - archs = args.architectures.split(';') - - if args.heuristics_problems_file: - filter_manifest_and_write_heuristics_file(manifest, args) - - GenerateSM50(manifest, args.cuda_version) - GenerateSM60(manifest, args.cuda_version) - GenerateSM61(manifest, args.cuda_version) - GenerateSM70(manifest, args.cuda_version) - GenerateSM75(manifest, args.cuda_version) - GenerateSM80(manifest, args.cuda_version) - GenerateSM89(manifest, args.cuda_version) - GenerateSM90(manifest, args.cuda_version) - - blackwell_arch_list = [ - "100a", "100f", - "101a", "101f", - "103a", "103f", - "110a", "110f", - "120a", "120f", - "121a", "121f", - ] - blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs) - if blackwell_enabled_arch: - GenerateSM100(manifest, args.cuda_version) - GenerateSM120(manifest, args.cuda_version) - - if 'library' in args.generator_target.split(','): - manifest.emit(GeneratorTarget.Library) - - if 'kernel_testlist_l0' in args.generator_target.split(','): - emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L0") - - if 'kernel_testlist_l1' in args.generator_target.split(','): - emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L1") - - if args.selected_kernel_list is not None: - if len(manifest.selected_kernels) > 0: - with open(args.selected_kernel_list, 'w') as file_writer: - for line in manifest.selected_kernels: - file_writer.write("%s\n" % line) - -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py deleted file mode 100644 index 83421a06427acdc3b059855991cf95a1d2f118b3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py +++ /dev/null @@ -1,415 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for selecting CUTLASS library kernels based on problem description -""" -import json -import csv - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * - from cutlass_library.generator import * - from cutlass_library.heuristics_provider import * -except ImportError: - from library import * - from generator import * - from heuristics_provider import * - -try: - from .sm90_utils import ( - get_valid_schedules, - generate_data_types_from_math_instruction, - fix_alignments, - ) -except ImportError: - from sm90_utils import ( - get_valid_schedules, - generate_data_types_from_math_instruction, - fix_alignments, - ) - -_LOGGER = logging.getLogger(__name__) - -dtype_map = {v: k for k, v in DataTypeNames.items()} - -def serialize_heuristics_results_to_json(problems_with_configs, outfile_path): - """ - Utilitiy function to write heuristics results to a json file for debug - - args: - problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict - outfile_path: Outfile path - - returns: - None - """ - pc_copy = problems_with_configs.copy() - for p in pc_copy: - for k, v in p.items(): - if isinstance(v, DataType): - p[k] = DataTypeNames[v] - elif isinstance(v, LayoutType): - p[k] = ShortLayoutTypeNames[v] - configs = p['configs'] - for c in configs: - for k, v in c.items(): - if isinstance(v, DataType): - c[k] = DataTypeNames[v] - elif isinstance(v, LayoutType): - c[k] = ShortLayoutTypeNames[v] - with open(outfile_path, 'w') as f: - json.dump(pc_copy, f, indent=2) - -def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None): - """ - Get heuristic-suggested GEMM kernel configurations for a single GEMM problem. - - args: - m, n, k: GEMM dimensions - batch_count: batch count - layouts: tuple of layouts of type LayoutType - use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions - count: Number of configs to return - provider: Heuristics provider to use - - returns: - A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys: - - 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size - - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size - - 'stages': kernel pipeline stage count - - 'cluster_m', 'cluster_n', 'cluster_k': cluster size - - 'layout_a', 'layout_b': input tensor layouts of type LayoutType - - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements - - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType - - 'swizzle_size' : suggested threadblock swizzle - - 'split_k_slices': number of partitions of the k dimension for splitK - - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n') - """ - if provider is None: - provider = MatmulHeuristics() - return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count) - -def get_gemm_configs(problems, provider=None, count=1): - """ - Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems. - - args: - problems: List of dictionaries describing GEMM problems with the following keys: - - 'm', 'n', 'k': Matrix dimensions (required) - - 'dtype_a': Data type of matrix A (required) - - 'dtype_b': Data type of matrix B (required) - - 'dtype_c': Data type of matrix C (default: None) - - 'dtype_d': Data type of matrix D (required) - - 'dtype_acc': Compute data type (default 'f32') - - 'layout': Operation layout (e.g. 'tnt') - - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements) - - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements) - - 'alpha': Scalar multiplier for A*B (default: 1.0) - - 'beta': Scalar multiplier for C (default: 0.0) - - 'batch_count': Number of GEMM operations in batch (default: 1) - - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True) - provider: Heuristics provider to use - count: Number of configurations to return per problem (defualt: 1) - - returns: - A copy of the input dictionary, with key `configs` added containing the selected gemm configs - """ - ret = [] - - for problem in problems: - problem = problem.copy() - - try: - m = problem['m'] - n = problem['n'] - k = problem['k'] - dtype_a = problem['dtype_a'] - dtype_b = problem['dtype_b'] - dtype_d = problem['dtype_d'] - layout = problem['layout'] - except KeyError as e: - _LOGGER.error(f"Missing required parameter {e} for problem {problem}") - raise - - operation = problem.get('operation', 'gemm') - batch_count = problem.get('batch_count', 1) - dtype_acc = problem.get('dtype_acc', 'f32') - dtype_c = problem.get('dtype_c', None) - alpha = problem.get('alpha', 1.0) - beta = problem.get('beta', 0.0) - use_fast_acc = problem.get('use_fast_acc', True) - - if operation != OperationKindNames[OperationKind.Gemm]: - raise ValueError(f"Unsupported operation {operation}") - if not (len(layout) == 3 and all(c in "nt" for c in layout)): - raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}") - layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout) - - try: - dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()] - dtypes = tuple(dtype_map[dt] for dt in dtype_list) - except KeyError as dt: - _LOGGER.error(f"Unsupported data type: {dt}") - raise - - alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]]) - alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]]) - - configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider) - problem['configs'] = configs - - ret.append(problem) - - return ret - - -def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs): - """ - Generate CUTLASS operations based on the list of configs provided by the heuristic provider - - args: - manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) - cuda_version: Cuda compiler version for generating cutlass operations - kernel_configs: list of configs generated by the heuristic - - returns: - (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations - """ - min_cc = 100 - max_cc = 101 - if manifest is None: - # Use a dummy manifest so we can use existing CreateGemmOperator functions - manifest = Manifest() - - configs = [] - operations = [] - for config in kernel_configs: - layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]]) - element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] - - # nvMMH assumes 2sm instruction for !(cluster_m % 2) - is_2sm = config['cluster_m'] % 2 == 0 - instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4] - math_instruction = MathInstruction( - instruction_shape, - element_a, element_b, element_accumulator, - OpcodeClass.TensorOp, - MathOperation.multiply_add - ) - - data_types = [ - { - "a_type" : math_instruction.element_a, - "b_type" : math_instruction.element_b, - "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator, - "d_type" : element_d, - "acc_type" : math_instruction.element_accumulator, - "epi_type" : math_instruction.element_accumulator, - } - ] - - tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k']) - tile_description = TileDescription( - [instruction_shape[0] * tile_multiplier[0], - instruction_shape[1] * tile_multiplier[1], - instruction_shape[2] * 4 * tile_multiplier[2]], - 0, - [4,1,1], - math_instruction, - min_cc, - max_cc, - cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) - ) - - schedules = [] - if is_2sm: - schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]) - else: - schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]) - - for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x): - configs.append(config) - operations.append(o) - - - return configs, operations - - -def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs): - """ - Generate CUTLASS operations based on the list of configs provided by the heuristic provider - - args: - manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) - cuda_version: Cuda compiler version for generating cutlass operations - kernel_configs: list of configs generated by the heuristic - - returns: - (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations - """ - min_cc, max_cc = 90, 90 - - if manifest is None: - # Use a dummy manifest so we can use existing CreateGemmOperator functions - manifest = Manifest() - - configs = [] - operations = [] - for config in kernel_configs: - - is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128) - layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1]) - element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] - - # instr shape and warp config are unused for emitting 3x collective builder code - dummy_instr_shape = [0, 0, 0] - math_instruction = MathInstruction( - dummy_instr_shape, - element_a, element_b, element_accumulator, - OpcodeClass.TensorOp, - MathOperation.multiply_add - ) - - data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d) - if is_aligned: - layout = fix_alignments(data_types, layout, alignment_bits=128) - - # instr shape and warp config are unused for emitting 3x collective builder code - dummy_warp_count = [0, 0, 0] - tile_description = TileDescription( - [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']], - 0, - dummy_warp_count, - math_instruction, - min_cc, - max_cc, - cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) - ) - - schedules, stream_k_schedules = get_valid_schedules( - tile_description=tile_description, - cuda_version=cuda_version, - is_aligned=is_aligned, - data_types=data_types, - instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic - layout=layout, - gemm_kind=GemmKind.Universal3x, - enable_fp8_fast_acc=config['use_fast_acc'] - ) - - if len(schedules): - for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x): - configs.append(config) - operations.append(o) - - if len(stream_k_schedules): - for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, - stream_k_schedules, - tile_schedulers=[TileSchedulerType.StreamK]): - configs.append(config) - operations.append(o) - - - return configs, operations - -def filter_manifest_and_write_heuristics_file(manifest, args): - """ - Prune a manifest according to heuristics suggestions from the problems file - - args: - manifest: Cutlass manifest to prune - args: generator.py args, requires: - - args.heuristics_problems_file - - args.heuristics_gpu - - args.heuristics_testlist_file - - returns: - A list of dictionaries, each of which has information about an operation and a problem from the input problems - """ - heuristics_problems = [] - with open(args.heuristics_problems_file, 'r') as f: - heuristics_problems = json.load(f) - gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu - mmh = MatmulHeuristics(gpu=gpu) - if any(('100' in arch) for arch in args.architectures.split(';')): - mmh.set_cta_div_n(64) - problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem) - - all_configs_and_operations = [] - operations = [] - for problem in problems_with_configs: - if any('90' in arch for arch in args.architectures.split(';')): - problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) - if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')): - problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) - - operations += problem_operations - problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'} - with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)] - all_configs_and_operations += with_problem_size - - for operation in operations: - manifest.add_kernel_filter(f"^{operation.procedural_name()}$") - if not all_configs_and_operations: - raise Exception("No valid configurations generated") - write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file) - return all_configs_and_operations - -def write_profiler_testlist_to_csv(configs_list, outfile_path): - """ - Write a list of configs to a testlist to be consumed by cutlass_profiler - - args: - configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries - outfile_path: Outfile path - - returns: - None - """ - profiler_testlist = configs_list.copy() - for c in profiler_testlist: - for k, v in c.items(): - if isinstance(v, DataType): - c[k] = DataTypeNames[v] - elif isinstance(v, LayoutType): - c[k] = ShortLayoutTypeNames[v] - - with open(outfile_path, mode='w', newline='') as ofile: - k_names = profiler_testlist[0].keys() - - writer = csv.DictWriter(ofile, fieldnames=k_names) - writer.writeheader() - writer.writerows(profiler_testlist) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py deleted file mode 100644 index 01a4112a34c87d73a792cce368fede96a9315ac1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py +++ /dev/null @@ -1,175 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Providers for kernel selection heuristics -""" - -import sys -import os -import glob -import logging -import ctypes -import functools - - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import DataType, LayoutType -except ImportError: - from library import DataType, LayoutType - -class MatmulHeuristics: - - def __init__(self, gpu = None): - import nvMatmulHeuristics - self.mmh_lib = nvMatmulHeuristics - self.gpu = gpu - - if 'CUTLASS_NVMMH_SO_PATH' in os.environ: - nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH']) - else: - nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx - - self.lh = nvmmhInterfaceEx( - backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"], - flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING, - load_discovery_implicitly=True, - gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None - ) - self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"]) - - def _layout_from_cutlass(self, layouts): - assert(len(layouts)==3) - full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts) - input_layouts = full_layout_str[:2].upper() - lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR") - return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout] - - def _precision_from_cutlass_dtypes(self, dtypes): - dtype_to_cublas = { - DataType.f64: 'D', - DataType.f32: 'S', - DataType.f16: 'H', - DataType.bf16: 'T', - DataType.e4m3: 'Q', - DataType.e5m2: 'R', - DataType.s32: 'I', - DataType.s8: 'B', - } - - dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes - - a_c = dtype_to_cublas[dtype_a] - - if a_c.lower() != 'q': - return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] - else: - return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] - - def set_cta_div_n(self, div_n): - cta_n_div_requirement = ctypes.c_int(div_n) - self.lh.setBackendValueProperty( - self.backend, - self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, - ctypes.byref(cta_n_div_requirement), - ctypes.sizeof(cta_n_div_requirement) - ) - - def set_cta_div_m(self, div_m): - cta_m_div_requirement = ctypes.c_int(div_m) - self.lh.setBackendValueProperty( - self.backend, - self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, - ctypes.byref(cta_m_div_requirement), - ctypes.sizeof(cta_m_div_requirement) - ) - - def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1): - if use_fast_acc: - disable_fast_acc_for_fp8 = ctypes.c_int(0) - else: - disable_fast_acc_for_fp8 = ctypes.c_int(1) - self.lh.setBackendValueProperty( - self.backend, - self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8, - ctypes.byref(disable_fast_acc_for_fp8), - ctypes.sizeof(disable_fast_acc_for_fp8) - ) - - precision = self._precision_from_cutlass_dtypes(dtypes) - layout = self._layout_from_cutlass(layouts) - - matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count) - configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision) - - ret = [] - for c in configs: - kernel = c['kernel'] - problem = c['problem'] - - r = {} - r['estimated_runtime'] = c['runtime'] - r['cta_tile_m'] = kernel.cta_tile_m - r['cta_tile_n'] = kernel.cta_tile_n - r['cta_tile_k'] = kernel.cta_tile_k - r['instr_tile_m'] = kernel.instr_tile_m - r['instr_tile_n'] = kernel.instr_tile_n - r['instr_tile_k'] = kernel.instr_tile_k - r['warp_tile_m'] = kernel.warp_tile_m - r['warp_tile_n'] = kernel.warp_tile_n - r['warp_tile_k'] = kernel.warp_tile_k - r['cluster_m'] = kernel.cluster_m - r['cluster_n'] = kernel.cluster_n - r['cluster_k'] = 1 - r['layout_a'] = layouts[0] - r['layout_b'] = layouts[1] - r['layout_d'] = layouts[2] - r['dtype_a'] = dtypes[0] - r['dtype_b'] = dtypes[1] - r['dtype_acc'] = dtypes[2] - r['dtype_c'] = dtypes[3] - r['dtype_d'] = dtypes[4] - r['alignment_a'] = align_a - r['alignment_b'] = align_b - r['swizzle_size'] = kernel.swizzle_factor - r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n' - r['split_k_slices'] = kernel.split_k - r['use_fast_acc'] = use_fast_acc - r['voidC'] = voidC - - ret.append(r) - - return ret - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py deleted file mode 100644 index 56d22dc4b0705b4813b15b1b09decf53b38f7f37..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py +++ /dev/null @@ -1,1531 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Data types and tags used for emitting CUTLASS C++ kernels -""" - -import enum -import re - -# The following block implements enum.auto() for Python 3.5 variants that don't include it such -# as the default 3.5.2 on Ubuntu 16.04. -# -# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility - -try: - from enum import auto as enum_auto -except ImportError: - __cutlass_library_auto_enum = 0 - def enum_auto() -> int: - global __cutlass_library_auto_enum - i = __cutlass_library_auto_enum - __cutlass_library_auto_enum += 1 - return i - -################################################################################################### - -# -class GeneratorTarget(enum.Enum): - Library = enum_auto() -# -GeneratorTargetNames = { - GeneratorTarget.Library: 'library' -} -# - -################################################################################################### - -# -class DataType(enum.Enum): - void = enum_auto() # primarily used to disable C tensor for epilogues - b1 = enum_auto() - u2 = enum_auto() - u4 = enum_auto() - u8 = enum_auto() - u16 = enum_auto() - u32 = enum_auto() - u64 = enum_auto() - s2 = enum_auto() - s4 = enum_auto() - s8 = enum_auto() - s16 = enum_auto() - s32 = enum_auto() - s64 = enum_auto() - e4m3 = enum_auto() - e5m2 = enum_auto() - f8 = enum_auto() - f6 = enum_auto() - f4 = enum_auto() - e3m2 = enum_auto() - e2m3 = enum_auto() - e2m1 = enum_auto() - ue8m0 = enum_auto() - ue4m3 = enum_auto() - f16 = enum_auto() - bf16 = enum_auto() - f32 = enum_auto() - tf32 = enum_auto() - f64 = enum_auto() - cf16 = enum_auto() - cbf16 = enum_auto() - cf32 = enum_auto() - ctf32 = enum_auto() - cf64 = enum_auto() - cs2 = enum_auto() - cs4 = enum_auto() - cs8 = enum_auto() - cs16 = enum_auto() - cs32 = enum_auto() - cs64 = enum_auto() - cu2 = enum_auto() - cu4 = enum_auto() - cu8 = enum_auto() - cu16 = enum_auto() - cu32 = enum_auto() - cu64 = enum_auto() - invalid = enum_auto() - -# -ShortDataTypeNames = { - DataType.s32: 'i', - DataType.e4m3: 'e4m3', - DataType.e5m2: 'e5m2', - DataType.f16: 'h', - DataType.f32: 's', - DataType.f64: 'd', - DataType.cf32: 'c', - DataType.cf64: 'z', - DataType.f8: 'f8', - DataType.f6: 'f6', - DataType.f4: 'f4', -} - -# -DataTypeNames = { - DataType.void: "void", - DataType.b1: "b1", - DataType.u2: "u2", - DataType.u4: "u4", - DataType.u8: "u8", - DataType.u16: "u16", - DataType.u32: "u32", - DataType.u64: "u64", - DataType.s2: "s2", - DataType.s4: "s4", - DataType.s8: "s8", - DataType.s16: "s16", - DataType.s32: "s32", - DataType.s64: "s64", - DataType.e4m3: 'e4m3', - DataType.e5m2: 'e5m2', - DataType.f8: 'f8', - DataType.f6: 'f6', - DataType.f4: 'f4', - DataType.e2m3: 'e2m3', - DataType.e3m2: 'e3m2', - DataType.e2m1: 'e2m1', - DataType.ue8m0: 'ue8m0', - DataType.ue4m3: 'ue4m3', - DataType.f16: "f16", - DataType.bf16: "bf16", - DataType.f32: "f32", - DataType.tf32: "tf32", - DataType.f64: "f64", - DataType.cf16: "cf16", - DataType.cbf16: "cbf16", - DataType.cf32: "cf32", - DataType.ctf32: "ctf32", - DataType.cf64: "cf64", - DataType.cu2: "cu2", - DataType.cu4: "cu4", - DataType.cu8: "cu8", - DataType.cu16: "cu16", - DataType.cu32: "cu32", - DataType.cu64: "cu64", - DataType.cs2: "cs2", - DataType.cs4: "cs4", - DataType.cs8: "cs8", - DataType.cs16: "cs16", - DataType.cs32: "cs32", - DataType.cs64: "cs64", -} - -DataTypeTag = { - DataType.void: "void", - DataType.b1: "cutlass::uint1b_t", - DataType.u2: "cutlass::uint2b_t", - DataType.u4: "cutlass::uint4b_t", - DataType.u8: "uint8_t", - DataType.u16: "uint16_t", - DataType.u32: "uint32_t", - DataType.u64: "uint64_t", - DataType.s2: "cutlass::int2b_t", - DataType.s4: "cutlass::int4b_t", - DataType.s8: "int8_t", - DataType.s16: "int16_t", - DataType.s32: "int32_t", - DataType.s64: "int64_t", - DataType.e4m3: 'cutlass::float_e4m3_t', - DataType.e5m2: 'cutlass::float_e5m2_t', - DataType.f8: 'cutlass::type_erased_dynamic_float8_t', - DataType.f6: 'cutlass::type_erased_dynamic_float6_t', - DataType.f4: 'cutlass::type_erased_dynamic_float4_t', - DataType.e2m3: 'cutlass::float_e2m3_t', - DataType.e3m2: 'cutlass::float_e3m2_t', - DataType.e2m1: 'cutlass::float_e2m1_t', - DataType.ue8m0: 'cutlass::float_ue8m0_t', - DataType.ue4m3: 'cutlass::float_ue4m3_t', - DataType.f16: "cutlass::half_t", - DataType.bf16: "cutlass::bfloat16_t", - DataType.f32: "float", - DataType.tf32: "cutlass::tfloat32_t", - DataType.f64: "double", - DataType.cf16: "cutlass::complex", - DataType.cbf16: "cutlass::complex", - DataType.cf32: "cutlass::complex", - DataType.ctf32: "cutlass::complex", - DataType.cf64: "cutlass::complex", - DataType.cu2: "cutlass::complex", - DataType.cu4: "cutlass::complex", - DataType.cu8: "cutlass::complex", - DataType.cu16: "cutlass::complex", - DataType.cu32: "cutlass::complex", - DataType.cu64: "cutlass::complex", - DataType.cs2: "cutlass::complex", - DataType.cs4: "cutlass::complex", - DataType.cs8: "cutlass::complex", - DataType.cs16: "cutlass::complex", - DataType.cs32: "cutlass::complex", - DataType.cs64: "cutlass::complex", -} - -DataTypeSize = { - DataType.void: 0, - DataType.b1: 1, - DataType.u2: 2, - DataType.u4: 4, - DataType.u8: 8, - DataType.u16: 16, - DataType.u32: 32, - DataType.u64: 64, - DataType.s2: 2, - DataType.s4: 4, - DataType.s8: 8, - DataType.s16: 16, - DataType.s32: 32, - DataType.s64: 64, - DataType.e4m3: 8, - DataType.e5m2: 8, - DataType.f8: 8, - DataType.f6: 6, - DataType.f4: 4, - DataType.e2m3: 6, - DataType.e3m2: 6, - DataType.e2m1: 4, - DataType.ue8m0: 8, - DataType.ue4m3: 8, - DataType.f16: 16, - DataType.bf16: 16, - DataType.f32: 32, - DataType.tf32: 32, - DataType.f64: 64, - DataType.cf16: 32, - DataType.cbf16: 32, - DataType.cf32: 64, - DataType.ctf32: 32, - DataType.cf64: 128, - DataType.cu2: 4, - DataType.cu4: 8, - DataType.cu8: 16, - DataType.cu16: 32, - DataType.cu32: 64, - DataType.cu64: 128, - DataType.cs2: 4, - DataType.cs4: 8, - DataType.cs8: 16, - DataType.cs16: 32, - DataType.cs32: 64, - DataType.cs64: 128, -} - -################################################################################################### -# -class BlasMode(enum.Enum): - symmetric = enum_auto() - hermitian = enum_auto() - -# -BlasModeTag = { - BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', - BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', -} - -# -class ComplexTransform(enum.Enum): - none = enum_auto() - conj = enum_auto() - -# -ComplexTransformTag = { - ComplexTransform.none: 'cutlass::ComplexTransform::kNone', - ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', -} - -# Used for cutlass3x complex kernel collective mainloop builder instantiation -ComplexTransformTag3x = { - ComplexTransform.none: 'cute::identity', - ComplexTransform.conj: 'cute::conjugate', -} - -# -RealComplexBijection = [ - (DataType.f16, DataType.cf16), - (DataType.f32, DataType.cf32), - (DataType.f64, DataType.cf64), -] - -# -def is_complex(data_type): - for r, c in RealComplexBijection: - if data_type == c: - return True - return False - -def is_block_scaled(gemm_kind): - return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) - -def is_blockwise(gemm_kind): - return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) - -def is_grouped(gemm_kind): - return gemm_kind in (GemmKind.GroupedUniversal3x, - GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) - -# -def get_complex_from_real(real_type): - for r, c in RealComplexBijection: - if real_type == r: - return c - return DataType.invalid - -# -def get_real_from_complex(complex_type): - for r, c in RealComplexBijection: - if complex_type == c: - return r - return DataType.invalid - -# TMA requires an alignment of 128 bits for all data types -def get_tma_alignment(data_type): - if data_type == DataType.void: - return 0 - elif DataTypeSize[data_type] == 6: - return 128 # 96B alignment for 16U6 format - else: - return 128 // DataTypeSize[data_type] - -# -class ComplexMultiplyOp(enum.Enum): - multiply_add = enum_auto() - gaussian = enum_auto() - -################################################################################################### - -# -class MathOperation(enum.Enum): - multiply_add = enum_auto() - multiply_add_saturate = enum_auto() - multiply_add_mixed_input_upcast = enum_auto() - xor_popc = enum_auto() - and_popc = enum_auto() - multiply_add_fast_bf16 = enum_auto() - multiply_add_fast_f16 = enum_auto() - multiply_add_fast_f32 = enum_auto() - multiply_add_complex_fast_f32 = enum_auto() - multiply_add_complex = enum_auto() - multiply_add_complex_gaussian = enum_auto() - multiply_add_fast_accum = enum_auto() - -# -MathOperationTag = { - MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', - MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', - MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', - MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', - MathOperation.and_popc: 'cutlass::arch::OpAndPopc', - MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', - MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', - MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', - MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', - MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', - MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', - MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum', -} - -################################################################################################### - -# -class LayoutType(enum.Enum): - ColumnMajor = enum_auto() - RowMajor = enum_auto() - ColumnMajorInterleaved2 = enum_auto() - RowMajorInterleaved2 = enum_auto() - ColumnMajorInterleaved32 = enum_auto() - RowMajorInterleaved32 = enum_auto() - ColumnMajorInterleaved64 = enum_auto() - RowMajorInterleaved64 = enum_auto() - TensorNWC = enum_auto() - TensorNHWC = enum_auto() - TensorNDHWC = enum_auto() - TensorNCHW = enum_auto() - TensorNGHWC = enum_auto() - TensorNC32HW32 = enum_auto() - TensorNC64HW64 = enum_auto() - TensorC32RSK32 = enum_auto() - TensorC64RSK64 = enum_auto() - TensorKCS = enum_auto() - TensorKCSR = enum_auto() - TensorKCSRT = enum_auto() - -# -LayoutTag = { - LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', - LayoutType.RowMajor: 'cutlass::layout::RowMajor', - LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', - LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', - LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', - LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', - LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', - LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', - LayoutType.TensorNWC: 'cutlass::layout::TensorNWC', - LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', - LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', - LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', - LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', - LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', - LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', - LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', - LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', - LayoutType.TensorKCS: 'cutlass::layout::TensorKCS', - LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR', - LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT' -} - -# -TransposedLayout = { - LayoutType.ColumnMajor: LayoutType.RowMajor, - LayoutType.RowMajor: LayoutType.ColumnMajor, - LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, - LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, - LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, - LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, - LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, - LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, - LayoutType.TensorNHWC: LayoutType.TensorNHWC -} - -# -ShortLayoutTypeNames = { - LayoutType.ColumnMajor: 'n', - LayoutType.ColumnMajorInterleaved2: 'n2', - LayoutType.ColumnMajorInterleaved32: 'n32', - LayoutType.ColumnMajorInterleaved64: 'n64', - LayoutType.RowMajor: 't', - LayoutType.RowMajorInterleaved2: 't2', - LayoutType.RowMajorInterleaved32: 't32', - LayoutType.RowMajorInterleaved64: 't64', - LayoutType.TensorNWC: 'nwc', - LayoutType.TensorNHWC: 'nhwc', - LayoutType.TensorNDHWC: 'ndhwc', - LayoutType.TensorNCHW: 'nchw', - LayoutType.TensorNGHWC: 'nghwc', - LayoutType.TensorNC32HW32: 'nc32hw32', - LayoutType.TensorNC64HW64: 'nc64hw64', - LayoutType.TensorC32RSK32: 'c32rsk32', - LayoutType.TensorC64RSK64: 'c64rsk64', - LayoutType.TensorKCS: 'kcs', - LayoutType.TensorKCSR: 'kcsr', - LayoutType.TensorKCSRT: 'kcsrt' -} - -# -ShortComplexLayoutNames = { - (LayoutType.ColumnMajor, ComplexTransform.none): 'n', - (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', - (LayoutType.RowMajor, ComplexTransform.none): 't', - (LayoutType.RowMajor, ComplexTransform.conj): 'h' -} - -################################################################################################### -class KernelScheduleType(enum.Enum): - ScheduleAuto = enum_auto() - Multistage = enum_auto() - CpAsyncWarpSpecialized = enum_auto() - CpAsyncWarpSpecializedPingpong = enum_auto() - CpAsyncWarpSpecializedCooperative = enum_auto() - Tma = enum_auto() - TmaWarpSpecialized = enum_auto() - TmaWarpSpecializedPingpong = enum_auto() - TmaWarpSpecializedCooperative = enum_auto() - TmaWarpSpecializedFP8FastAccum = enum_auto() - TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() - TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() - ImplicitTmaWarpSpecializedSm90 = enum_auto() - PtrArrayTmaWarpSpecializedCooperative = enum_auto() - PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() - PtrArrayTmaWarpSpecializedPingpong = enum_auto() - PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() - - BlockwiseTmaWarpSpecializedCooperative = enum_auto() - PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() - BlockwiseTmaWarpSpecializedPingpong = enum_auto() - PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto() - - TmaWarpSpecialized1SmSm100 = enum_auto() - TmaWarpSpecialized2SmSm100 = enum_auto() - ImplicitTmaWarpSpecialized1SmSm100 = enum_auto() - ImplicitTmaWarpSpecialized2SmSm100 = enum_auto() - - PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() - PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() - - PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() - PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() - PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() - PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto() - PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto() - PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() - PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() - PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() - - SparseTmaWarpSpecialized1SmSm100 = enum_auto() - SparseTmaWarpSpecialized2SmSm100 = enum_auto() - - BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() - BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() - Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() - Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() - - BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() - BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() - - PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() - PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() - - - Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() - Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() - Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() - Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() - - # FP4 Ultra - MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() - MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() - MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() - MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() - - MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() - MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() - MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() - MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() - - MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() - MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() - MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() - MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() - - PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() - - PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() - - PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() - PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() - - Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() - Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() - Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() - Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto() - Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() - Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() - - F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() - - BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() - BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() - -KernelScheduleTag = { - KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', - KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', - KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized', - KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong', - KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative', - KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', - KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', - KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', - KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', - KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', - KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', - KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', - - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise', - KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise', - - KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', - KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', - - KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100', - KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100', - - KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100', - KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100', - - KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100', - KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100', - - KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100', - KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', - KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', - KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', - - KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100', - KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100', - - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100', - - KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', - KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', - KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', - KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', - - # FP4 Ultra - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', - - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', - - KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', - KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', - KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', - KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', - - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise', - - KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", - KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", - KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100", - KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100", - KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100", - KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100", - KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", - KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", - - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', - - KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120', - KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120', - KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120', - KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120', - KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120', - KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120', - - KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120', - - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120', - KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120', -} - -# -KernelScheduleSuffixes = { - KernelScheduleType.ScheduleAuto: '', - KernelScheduleType.Multistage: '_cpasync', - KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized', - KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong', - KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative', - KernelScheduleType.Tma: '_unspecialized', - KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', - KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', - KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', - KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', - KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', - - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', - - KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', - - KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: '_2sm', - - KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm', - - KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: '_2sm', - - KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', - KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', - KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', - - KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm', - - KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', - KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', - KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', - KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', - - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', - - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', - - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', - - KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', - KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', - KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', - - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', - - KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', - KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', - KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', - KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', - KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', - KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', - KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', - KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', - - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', - - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', - - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', - KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', - - KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q', - KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q', - KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16', - KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs16', - KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32', - KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32', - - KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q', - - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q', - KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q' -} - -class EpilogueScheduleType(enum.Enum): - ScheduleAuto = enum_auto() - EpilogueTransposed = enum_auto() - NoSmemWarpSpecialized = enum_auto() - PtrArrayNoSmemWarpSpecialized = enum_auto() - NoSmemWarpSpecialized1Sm = enum_auto() - NoSmemWarpSpecialized2Sm = enum_auto() - FastF32NoSmemWarpSpecialized1Sm = enum_auto() - FastF32NoSmemWarpSpecialized2Sm = enum_auto() - BlockwiseNoSmemWarpSpecialized1Sm = enum_auto() - BlockwiseNoSmemWarpSpecialized2Sm = enum_auto() - PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() - PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() - PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto() - PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto() - PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto() - PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto() - TmaWarpSpecialized = enum_auto() - TmaWarpSpecializedCooperative = enum_auto() - TmaWarpSpecialized1Sm = enum_auto() - TmaWarpSpecialized2Sm = enum_auto() - PtrArrayTmaWarpSpecialized1Sm = enum_auto() - PtrArrayTmaWarpSpecialized2Sm = enum_auto() - PtrArrayTmaWarpSpecializedPingpong = enum_auto() - PtrArrayTmaWarpSpecializedCooperative = enum_auto() - -# -EpilogueScheduleTag = { - EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', - EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', - EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', - EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized', - EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm', - EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', - EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm', - EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm', - EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm', - EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm', - EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', - EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', - EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm', - EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm', - EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm', - EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm', - EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', - EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', - EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', - EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm', - EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative', - EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong', -} - -# -EpilogueScheduleSuffixes = { - EpilogueScheduleType.ScheduleAuto: '', - EpilogueScheduleType.EpilogueTransposed: '', - EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', - EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem', - EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem', - EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', - EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', - EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', - EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', - EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', - EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', - EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', - EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', - EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', - EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', - EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', - EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', - EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', - EpilogueScheduleType.TmaWarpSpecialized1Sm: '', - EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma', - EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', - EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', -} - -class EpilogueFunctor3x(enum.Enum): - LinearCombination = enum_auto() - LinearCombinationBlockScaleFactor = enum_auto() - -# -EpilogueFunctor3xTag = { - EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination', - EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor', -} - -# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type) -def is_tma_epilogue(epilogue_schedule_type): - return epilogue_schedule_type in [ - EpilogueScheduleType.ScheduleAuto, - EpilogueScheduleType.TmaWarpSpecialized, - EpilogueScheduleType.TmaWarpSpecializedCooperative, - EpilogueScheduleType.TmaWarpSpecialized1Sm, - EpilogueScheduleType.TmaWarpSpecialized2Sm, - EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, - EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, - EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, - EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, - ] - -def to_grouped_schedule(schedule, grouped): - if not grouped: - return schedule - - group_schedule_map = { - # SM90 - KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, - KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong, - KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, - KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, - EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, - EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, - # SM100 - KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100, - KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100, - KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, - KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, - KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, - KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100, - KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100, - KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100, - KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100, - KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, - EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, - EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, - EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, - EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm, - EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, - EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, - # SM103 - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, - KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, - } - - return group_schedule_map[schedule] - -class TileSchedulerType(enum.Enum): - Default = enum_auto() - Persistent = enum_auto() - StreamK = enum_auto() -# -TileSchedulerTag = { - TileSchedulerType.Default: 'void', - TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler', - TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler', -} - -# -TileSchedulerSuffixes = { - TileSchedulerType.Default: '', - TileSchedulerType.Persistent: '', - TileSchedulerType.StreamK: '_stream_k', -} - -################################################################################################### - -# -class SideMode(enum.Enum): - Left = enum_auto() - Right = enum_auto() - -# -SideModeTag = { - SideMode.Left: 'cutlass::SideMode::kLeft', - SideMode.Right: 'cutlass::SideMode::kRight' -} - -# -ShortSideModeNames = { - SideMode.Left: 'ls', - SideMode.Right: 'rs' -} - -################################################################################################### - -# -class FillMode(enum.Enum): - Lower = enum_auto() - Upper = enum_auto() - -# -FillModeTag = { - FillMode.Lower: 'cutlass::FillMode::kLower', - FillMode.Upper: 'cutlass::FillMode::kUpper' -} - -# -ShortFillModeNames = { - FillMode.Lower: 'l', - FillMode.Upper: 'u' -} - -################################################################################################### - -# -class DiagType(enum.Enum): - NonUnit = enum_auto() - Unit = enum_auto() - -# -DiagTypeTag = { - DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', - DiagType.Unit: 'cutlass::DiagType::kUnit' -} - -# -ShortDiagTypeNames = { - DiagType.NonUnit: 'nu', - DiagType.Unit: 'un' -} - -################################################################################################### - -# -class OpcodeClass(enum.Enum): - Simt = enum_auto() - TensorOp = enum_auto() - WmmaTensorOp = enum_auto() - SparseTensorOp = enum_auto() - BlockScaledTensorOp = enum_auto() - - -OpcodeClassNames = { - OpcodeClass.Simt: 'simt', - OpcodeClass.TensorOp: 'tensorop', - OpcodeClass.WmmaTensorOp: 'wmma_tensorop', - OpcodeClass.SparseTensorOp: 'sptensorop', - OpcodeClass.BlockScaledTensorOp: 'bstensorop' -} - -OpcodeClassTag = { - OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', - OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', - OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', - OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp', - OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp' -} - -################################################################################################### - -# -class OperationKind(enum.Enum): - Gemm = enum_auto() - RankK = enum_auto() - Rank2K = enum_auto() - Trmm = enum_auto() - Symm = enum_auto() - Conv2d = enum_auto() - Conv3d = enum_auto() - -# -OperationKindNames = { - OperationKind.Gemm: 'gemm' - , OperationKind.RankK: 'rank_k' - , OperationKind.Rank2K: 'rank_2k' - , OperationKind.Trmm: 'trmm' - , OperationKind.Symm: 'symm' - , OperationKind.Conv2d: 'conv2d' - , OperationKind.Conv3d: 'conv3d' -} - -# -class Target(enum.Enum): - library = enum_auto() -# -ArchitectureNames = { - 50: 'maxwell', - 60: 'pascal', - 61: 'pascal', - 70: 'volta', - 75: 'turing', - 80: 'ampere', - 89: 'ada', - 90: 'hopper' -} - -# -SharedMemPerCC = { - 70: 96, # 96KB of SMEM - 72: 96, # 96KB of SMEM - 75: 64, # 64KB of SMEM - 80: 163, # 163KB of SMEM - 1KB reserved for the driver - 86: 99, # 99KB of SMEM - 1KB reserved for the driver - 87: 163, # 163KB of SMEM - 1KB reserved for the driver - 89: 99, # 99KB of SMEM - 1KB reserved for the driver - 90: 227, # 227KB of SMEM - 1KB reserved for the driver - 100: 227, # 227KB of SMEM - 1KB reserved for the driver -} - -################################################################################################### - -# -def SubstituteTemplate(template, values): - text = template - changed = True - while changed: - changed = False - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - newtext = re.sub(regex, value, text) - if newtext != text: - changed = True - text = newtext - return text - -################################################################################################### - -# -class GemmKind(enum.Enum): - Gemm = enum_auto() - Sparse = enum_auto() - Universal = enum_auto() - Universal3x = enum_auto() - SparseUniversal3x = enum_auto() - PlanarComplex = enum_auto() - PlanarComplexArray = enum_auto() - Grouped = enum_auto() - BlockScaledUniversal3x = enum_auto() - GroupedUniversal3x = enum_auto() - GroupedBlockScaledUniversal3x = enum_auto() - BlockwiseUniversal3x = enum_auto() - GroupedBlockwiseUniversal3x = enum_auto() - -# -GemmKindNames = { - GemmKind.Gemm: "gemm", - GemmKind.Sparse: "spgemm", - GemmKind.Universal: "gemm", - GemmKind.Universal3x: "gemm", - GemmKind.SparseUniversal3x: "spgemm", - GemmKind.PlanarComplex: "gemm_planar_complex", - GemmKind.PlanarComplexArray: "gemm_planar_complex_array", - GemmKind.Grouped: "gemm_grouped", - GemmKind.BlockScaledUniversal3x: "gemm", - GemmKind.GroupedUniversal3x: "gemm_grouped", - GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped", - GemmKind.BlockwiseUniversal3x: "gemm", - GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped" -} - -# -class RankKKind(enum.Enum): - Universal = enum_auto() - -# -RankKKindNames = { - RankKKind.Universal: "rank_k" -} - -# -class TrmmKind(enum.Enum): - Universal = enum_auto() - -# -TrmmKindNames = { - TrmmKind.Universal: "trmm" -} - -# -class SymmKind(enum.Enum): - Universal = enum_auto() - -# -SymmKindNames = { - SymmKind.Universal: "symm" -} - -# -class EpilogueFunctor(enum.Enum): - LinearCombination = enum_auto() - LinearCombinationClamp = enum_auto() - -# -EpilogueFunctorTag = { - EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', - EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', -} - -# -class MixedInputMode(enum.Enum): - ConvertOnly = enum_auto() - ScaleOnly = enum_auto() - ScaleWithZeroPoint = enum_auto() - -# -class SwizzlingFunctor(enum.Enum): - Identity1 = enum_auto() - Identity2 = enum_auto() - Identity4 = enum_auto() - Identity8 = enum_auto() - Horizontal = enum_auto() - StridedDgradIdentity1 = enum_auto() - StridedDgradIdentity4 = enum_auto() - StridedDgradHorizontal = enum_auto() - StreamK = enum_auto() - -# -SwizzlingFunctorTag = { - SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', - SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', - SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', - SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', - SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', - SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', - SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', - SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', - SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK', -} - -# -class GroupScheduleMode(enum.Enum): - Device = enum_auto(), - Host = enum_auto() - -# -GroupScheduleModeTag = { - GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', - GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' -} - -# -ShortGroupScheduleModeNames = { - GroupScheduleMode.Device: 'Device', - GroupScheduleMode.Host: 'Host' -} - -################################################################################################### - -# -class ConvKind(enum.IntEnum): - Fprop = 0 - Dgrad = 1 - Wgrad = 2 - -# -ConvKindTag = { - ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', - ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', - ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' -} - -ConvKindNames = { - ConvKind.Fprop: 'fprop', - ConvKind.Dgrad: 'dgrad', - ConvKind.Wgrad: 'wgrad', -} - -class ConvMode(enum.IntEnum): - CrossCorrelation = 0 - Convolution = 1 - -# -class IteratorAlgorithm(enum.Enum): - Analytic = 0 - Optimized = 1 - FixedChannels = 2 - FewChannels = 3 - FixedStrideDilation = 4 - -# -IteratorAlgorithmTag = { - IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', - IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', - IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', - IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels', - IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation' -} - -IteratorAlgorithmNames = { - IteratorAlgorithm.Analytic: 'analytic', - IteratorAlgorithm.Optimized: 'optimized', - IteratorAlgorithm.FixedChannels: 'fixed_channels', - IteratorAlgorithm.FewChannels: 'few_channels', - IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation' -} - -# -class StrideSupport(enum.Enum): - Strided = 0 - Unity = 1 - Fixed = 2 - -# -StrideSupportTag = { - StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', - StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', - StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed' -} - -StrideSupportNames = { - StrideSupport.Strided: '', - StrideSupport.Unity: 'unity_stride', - StrideSupport.Fixed: 'fixed_stride' -} - -# -class GroupMode(enum.Enum): - NoneGroup = enum_auto() # dense conv (G=1) - SingleGroup = enum_auto() # grouped convolution (single group per CTA) - MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) - Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) - -# -GroupModeTag = { - GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone', - GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup', - GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup', - GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise', -} - -GroupModeNames = { - GroupMode.NoneGroup: '', - GroupMode.SingleGroup: 'single_group', - GroupMode.MultipleGroup: 'multiple_group', - GroupMode.Depthwise: 'depthwise', -} - -DynamicClusterShape = [0, 0, 1] - -################################################################################################### - -# -class MathInstruction: - def __init__(self, - instruction_shape, \ - element_a, element_b, element_accumulator, \ - opcode_class, math_operation = MathOperation.multiply_add \ - , element_scale_factor = None - ): - - self.instruction_shape = instruction_shape - self.element_a = element_a - self.element_b = element_b - self.element_accumulator = element_accumulator - self.opcode_class = opcode_class - self.math_operation = math_operation - self.element_scale_factor = element_scale_factor - -# -class TileDescription: - - def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None): - self.threadblock_shape = threadblock_shape - self.tile_shape = threadblock_shape - self.stages = stages - self.warp_count = warp_count - self.math_instruction = math_instruction - self.minimum_compute_capability = min_compute - self.maximum_compute_capability = max_compute - self.cluster_shape = cluster_shape - self.explicit_vector_sizes = explicit_vector_sizes - - def procedural_name(self): - if self.minimum_compute_capability >= 90: - return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( - tbm = self.threadblock_shape[0], - tbn = self.threadblock_shape[1], - tbk = self.threadblock_shape[2], - cm = self.cluster_shape[0], - cn = self.cluster_shape[1], - ck = self.cluster_shape[2], - s = self.stages) - else: - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) - -# -class Direct2dConvFixedStrideDilationTileDescription: - def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): - self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] - self.threadblock_output_shape = threadblock_output_shape - self.filter_shape = filter_shape - self.stages = stages - self.warp_count = warp_count - self.stride = stride - self.dilation = dilation - self.math_instruction = math_instruction - self.minimum_compute_capability = min_compute - self.maximum_compute_capability = max_compute - - def procedural_name(self): - str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], - self.threadblock_shape[1], - self.threadblock_shape[2], - self.threadblock_output_shape[0], - self.threadblock_output_shape[1], - self.threadblock_output_shape[2], - self.threadblock_output_shape[3], - self.stages, - self.filter_shape[0], - self.filter_shape[1]) - # Fixed Strided and dilation - if self.stride != [-1, -1] and self.dilation != [-1, -1]: - str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], - self.stride[1], - self.dilation[0], - self.dilation[1]) - return str_name - -# -class Direct2dConvFixedStrideDilationTileDescription: - def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): - self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] - self.threadblock_output_shape = threadblock_output_shape - self.filter_shape = filter_shape - self.stages = stages - self.warp_count = warp_count - self.stride = stride - self.dilation = dilation - self.math_instruction = math_instruction - self.minimum_compute_capability = min_compute - self.maximum_compute_capability = max_compute - - def procedural_name(self): - str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], - self.threadblock_shape[1], - self.threadblock_shape[2], - self.threadblock_output_shape[0], - self.threadblock_output_shape[1], - self.threadblock_output_shape[2], - self.threadblock_output_shape[3], - self.stages, - self.filter_shape[0], - self.filter_shape[1]) - # Fixed Strided and dilation - if self.stride != [-1, -1] and self.dilation != [-1, -1]: - str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], - self.stride[1], - self.dilation[0], - self.dilation[1]) - return str_name - -# -class TensorDescription: - def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): - self.element = element - self.layout = layout - self.alignment = alignment - self.complex_transform = complex_transform - -# -class SymmetricTensorDescription: - def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left): - self.element = element - self.layout = layout - self.fill_mode = fill_mode - self.alignment = alignment - self.complex_transform = complex_transform - self.side_mode = side_mode - -# -class TriangularTensorDescription: - def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none): - self.element = element - self.layout = layout - self.side_mode = side_mode - self.fill_mode = fill_mode - self.diag_type = diag_type - self.alignment = alignment - self.complex_transform = complex_transform - -# -def CalculateSmemUsage(operation): - cta_shape = operation.tile_description.threadblock_shape - stages = operation.tile_description.stages - - if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: - # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) - if DataTypeSize[operation.A.element] == 32: - elements_per_8b_md = 2 - elif DataTypeSize[operation.A.element] == 4: - elements_per_8b_md = 8 - else: - elements_per_8b_md = 4 - - smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ - DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ - cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md - else: - # Few BLAS3 operations only have A tensor - data_type_size_a = DataTypeSize[operation.A.element] - data_type_size_b = DataTypeSize[operation.A.element] - if operation.is_mixed_input(): - data_type_size_b = DataTypeSize[operation.B.element] - - smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ - data_type_size_b * cta_shape[1] * cta_shape[2] // 8 - - smem_usage = smem_per_stage * stages - return (smem_usage >> 10) - - -class GemmUniversalMode(enum.IntEnum): - """ - Types corresponding to GemmUniversalMode - """ - Gemm = 0 - GemmSplitKParallel = 1 - Batched = 2 - Array = 3 - - -class SplitKMode(enum.IntEnum): - """ - Types corresponding to SplitKMode - """ - NoneSplitK = 0 - Serial = 1 - Parallel = 2 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py deleted file mode 100644 index 5733ef26322794ee650dfa0c8c2b170bd8c6f3e5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py +++ /dev/null @@ -1,868 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for filtering CUTLASS library kernels and emitting library intitialization -and building code -""" - -import enum -import logging -import os.path -import shutil - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * - from cutlass_library.gemm_operation import * - from cutlass_library.rank_k_operation import * - from cutlass_library.rank_2k_operation import * - from cutlass_library.trmm_operation import * - from cutlass_library.symm_operation import * - from cutlass_library.conv2d_operation import * - from cutlass_library.conv3d_operation import * -except ImportError: - from library import * - from gemm_operation import * - from rank_k_operation import * - from rank_2k_operation import * - from trmm_operation import * - from symm_operation import * - from conv2d_operation import * - from conv3d_operation import * - -################################################################################################### -_LOGGER = logging.getLogger(__name__) - - -class EmitOperationKindAll: - """ - Emit the OperationKind-level CUTLASS library initialization code. - The code is generated in the {generated_path}/{operation_kind} directory - (e.g., tools/library/generated/gemm in the build directory, - for OperationKind=Gemm), in the all_{operation_kind}_operations.cu file - (e.g., all_gemm_operations.cu for OperationKind=Gemm). - That file declares several functions in namespace cutlass::library. - The functions all have this form, - - void initialize_{configuration_name}(Manifest& manifest); - - The file also _defines_ the following function in that namespace. - - void initialize_all_{operation_kind}_operations(Manifest& manifest); - - That function calls all of the functions declared in this file. - Those functions are defined in subdirectories - (which this class does not create). - """ - - def __init__(self, generated_path, kind, args): - self.generated_path = generated_path - self.kind = kind - self.args = args - - self.header_template =""" -/* - Generated by manifest.py - Do not edit. -*/ - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.entry_template = """ - -// -// Entry point to construct operations -// -void initialize_all_${operation_name}_operations(Manifest &manifest) { -""" - self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" - self.configuration_template =" initialize_${configuration_name}(manifest);\n" - - self.epilogue_template ="""} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -""" - - # - def __enter__(self): - _LOGGER.debug("*** EmitOperationKindAll::__enter__") - - self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind]) - _LOGGER.debug('*** operation_path (directory to create): ' + - str(self.operation_path)); - os.makedirs(self.operation_path, exist_ok=True) - - self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu") - _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") - - self.top_level_file = open(self.top_level_path, "w") - self.top_level_file.write(self.header_template) - - self.source_files = [self.top_level_path,] - - self.configurations = [] - - return self - - # - def emit(self, operations): - _LOGGER.debug('*** EmitOperationKindAll::emit') - _LOGGER.debug(f"*** len(operations): {len(operations)}") - _LOGGER.debug(f"*** min_cc list: {sorted(min_cc for min_cc, _ in operations.items())}") - - for min_cc, configurations in sorted(operations.items()): - _LOGGER.debug(f"*** min_cc={min_cc}") - - for configuration_name, _ in configurations.items(): - _LOGGER.debug(f"*** configuration_name={configuration_name}") - self.configurations.append(configuration_name) - self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) - - # - def __exit__(self, exception_type, exception_value, traceback): - _LOGGER.debug("*** EmitOperationKindAll::__exit__") - - self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]})) - - for configuration_name in self.configurations: - self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name})) - - self.top_level_file.write(self.epilogue_template) - self.top_level_file.close() - - -class EmitOperationKindLibrary: - """ - Emit the CUTLASS library initialization code for each OperationKind. - The code is generated in the directory - {generated_path}/{operation_kind}/{min_cc} - (e.g., tools/library/generated/gemm/90 in the build directory, - for min_cc=90 and OperationKind=Gemm), in the file - all_sm{min_cc}_{operation_kind}_operations.cu - (e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm). - The min_cc variable here indicates the minimum GPU architecture version - that the things to be initialized require. - For example, min_cc=90 indicates sm90. - - That file declares several functions in namespace cutlass::library. - The functions all have this form, - - void initialize_all_sm{min_cc}_{subclass_name}_{extended_name}_operations(Manifest& manifest); - - where extended_name is operation.extended_name() for all the operations - given to the emit method (which see below). (All operations for a given - configuration_name are guaranteed to have the same extended_name().) - - The file also _defines_ the following function in that namespace. - - void initialize_all_sm{min_cc}__{operation_kind}_operations(Manifest& manifest); - - That function calls all of the functions declared in this file. - Those functions are defined in subdirectories. - The mapping from OperationKind to emitter handles the details - of what happens in each of those subdirectories. - """ - - def __init__(self, generated_path, min_cc, kind, args): - self.generated_path = generated_path - self.min_cc = min_cc - self.kind = kind - self.args = args - self.emitters = { - OperationKind.Gemm: EmitGemmConfigurationLibrary, - OperationKind.Conv2d: EmitConv2dConfigurationLibrary, - OperationKind.Conv3d: EmitConv3dConfigurationLibrary, - OperationKind.RankK: EmitRankKConfigurationLibrary, - OperationKind.Rank2K: EmitRank2KConfigurationLibrary, - OperationKind.Trmm: EmitTrmmConfigurationLibrary, - OperationKind.Symm: EmitSymmConfigurationLibrary - } - - self.header_template =""" -/* - Generated by manifest.py - Do not edit. -*/ - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - self.entry_template = """ - -// -// Entry point to construct operations -// -void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) { -""" - self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" - self.configuration_template = " initialize_${configuration_name}(manifest);\n" - self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n" - self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n" - self.epilogue_template ="""} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -""" - - # - def __enter__(self): - _LOGGER.debug("*** EmitOperationKindLibrary::__enter__") - _LOGGER.debug(f"*** generated_path: {str(self.generated_path)}") - _LOGGER.debug(f"*** OperationKindNames[kind]: {OperationKindNames[self.kind]}") - _LOGGER.debug(f"*** min_cc: {self.min_cc}") - - self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc)) - _LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}") - os.makedirs(self.operation_path) - - self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu") - _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") - - self.top_level_file = open(self.top_level_path, "w") - self.top_level_file.write(self.header_template) - - self.source_files = {} - - # Each {operation_kind x cc} combination is further decomposed by the instruction - # types used. This dictionary used to track the file handles for the top-level - # files of each subclass - self.subclass_files = {} - - # Configurations in each sub class - self.subclass_configurations = {} - - return self - - # - def emit(self, configuration_name, operations): - _LOGGER.debug("*** EmitOperationKindLibrary::emit") - _LOGGER.debug(f"*** configuration_name: {configuration_name}") - - assert len(operations) > 0 - - # The extended name for all operations of a given configuration_name is guaranteed - # to be the same because extended_name() is used in defining configuration_name. Thus, - # we can safely use the extended_name() of the first operation. - extended_name = operations[0].extended_name() - _LOGGER.debug('*** extended_name (for all ops): ' + extended_name) - - # Create a directory for operations with this subclass if it does not exist - if extended_name not in self.subclass_files: - subclass_path = os.path.join(self.operation_path, extended_name) - _LOGGER.debug(f"*** subclass_path: {str(subclass_path)}") - os.mkdir(subclass_path) - - self.subclass_configurations[extended_name] = [] - - # Open a new top-level file for this sub class - subclass_top_level_path = os.path.join( - subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu") - _LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' + - 'OperationKind): ' + str(subclass_top_level_path)) - - self.subclass_files[extended_name] = open(subclass_top_level_path, "w") - self.subclass_files[extended_name].write(self.header_template) - - self.source_files[extended_name] = [subclass_top_level_path] - - subclass_dir = os.path.dirname(self.subclass_files[extended_name].name) - _LOGGER.debug('*** subclass_dir: ' + str(subclass_dir)) - - with self.emitters[self.kind](subclass_dir, configuration_name) as configuration_emitter: - for operation in operations: - configuration_emitter.emit(operation) - - _LOGGER.debug('*** configuration_emitter.configuration_path: ' + - str(configuration_emitter.configuration_path)) - self.source_files[extended_name].append(configuration_emitter.configuration_path) - - self.subclass_configurations[extended_name].append(configuration_name) - self.subclass_files[extended_name].write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) - - # - def __exit__(self, exception_type, exception_value, traceback): - _LOGGER.debug("*** EmitOperationKindLibrary::__exit__") - for subclass_name, subclass_file in sorted(self.subclass_files.items()): - subclass_cfg = { - 'min_cc': str(self.min_cc), - 'subclass_name': subclass_name, - 'operation_name': OperationKindNames[self.kind] - } - self.top_level_file.write(SubstituteTemplate(self.subclass_prototype_template, subclass_cfg)) - - self.top_level_file.write( - SubstituteTemplate(self.entry_template, { - 'min_cc': str(self.min_cc), - 'subclass_name': '', - 'operation_name': OperationKindNames[self.kind] - })) - - # Finish and close all subclass files - for subclass_name, subclass_file in sorted(self.subclass_files.items()): - subclass_cfg = { - 'min_cc': str(self.min_cc), - 'subclass_name': subclass_name, - 'operation_name': OperationKindNames[self.kind] - } - subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg)) - - for configuration in self.subclass_configurations[subclass_name]: - subclass_file.write( - SubstituteTemplate(self.configuration_template, { - 'configuration_name': configuration - })) - - subclass_file.write(self.epilogue_template) - subclass_file.close() - - # Write the call to initialize_all for this subclass to the top-level file - self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg)) - - self.top_level_file.write(self.epilogue_template) - self.top_level_file.close() - -class EmitInterfaceLibrary: - """ - Emit the topmost-level CUTLASS library initialization code. - The code is generated in the generated_path directory - (e.g., tools/library/generated in the build directory), - in the initialize_all.cpp file. - That file declares several functions in namespace cutlass::library. - The functions all have this form, - - void initialize_all_{operation_kind}_operations(Manifest& manifest); - - where {operation_kind} abbreviates the "kind" of operation - (e.g., gemm for matrix-matrix multiply, conv2d for 2-d convolution, - or trmm for triangular solve with multiple right-hand sides). - The definitions of these functions live in subdirectories. - - The file also _defines_ the following function in that namespace. - - void initialize_all(Manifest& manifest); - - That function first prepares the manifest, and then - calls all of the functions declared in this file. - """ - - def __init__(self, generated_path, operation_count, args): - self.generated_path = generated_path - self.args = args - - self.prototypes = [] - self.fn_calls = [] - self.operation_count = str(operation_count) - - self.top_level_hdr_template = ''' -/* - Generated by manifest.py - Do not edit. -*/ -''' - self.top_level_prologue = ''' - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -namespace cutlass { -\tnamespace library { - -${prototypes} -''' - - self.top_level_initialize_kind = ''' -\t\tvoid initialize_all_${kind}_operations(Manifest &manifest) { -${fn_calls} -\t\t} -''' - - self.top_level_initialize = ''' -\t\tvoid initialize_all(Manifest &manifest) { -\t\t\tmanifest.reserve(${operation_count});\n -${fn_calls} -\t\t} -''' - - self.top_level_suffix = ''' -\t} // namespace library -} // namespace cutlass - -''' - - # - def __enter__(self): - _LOGGER.debug("*** EmitInterfaceLibrary::__enter__") - - self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp') - _LOGGER.debug("*** top_level_path: " + str(self.top_level_path)) - - self.top_level_file = open(self.top_level_path, "w") - self.top_level_file.write(self.top_level_hdr_template) - - self.source_files = [self.top_level_path,] - - return self - - # - def emit(self, operation_name): - _LOGGER.debug("*** EmitInterfaceLibrary::emit") - _LOGGER.debug("*** operation_name: " + operation_name) - - self.prototypes.append(SubstituteTemplate( - "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);", - {'operation_kind': operation_name})) - - self.fn_calls.append(SubstituteTemplate( - "\t\t\tinitialize_all_${operation_kind}_operations(manifest);", - {'operation_kind': operation_name})) - - # - def __exit__(self, exception_type, exception_value, traceback): - _LOGGER.debug("*** EmitInterfaceLibrary::__exit__") - - self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)})) - - # Write out initialize_all method - self.top_level_file.write(SubstituteTemplate(self.top_level_initialize, - {'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)})) - - self.top_level_file.write(self.top_level_suffix) - self.top_level_file.close() - -################################################################################################### -################################################################################################### - -class Options: - def __init__(self): - pass - -################################################################################################### - -# -class Manifest: - - # - def __init__(self, args = None): - self.operations = {} - self.args = args - self.operation_count = 0 - self.operations_by_name = {} - - self.kernel_filter = '' - self.kernel_filter_list = [] - self.kernel_names = [] - self.operations_enabled = [] - self.selected_kernels = [] - self.ignore_kernel_names = [] - self.exclude_kernel_names = [] - self.compute_capabilities_baseline = [50,] - self.compute_capabilities_feature_set = ['50',] - self.curr_build_dir = '.' - self.filter_by_cc = True - - if self.args: - self.kernel_filter = self.args.kernels - self.curr_build_dir = args.curr_build_dir - - # A common user error is to use commas instead of semicolons. - if ',' in args.architectures: - raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) - - self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',] - self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set)) - - if args.filter_by_cc in ['false', 'False', '0']: - self.filter_by_cc = False - - if args.operations == 'all': - self.operations_enabled = [] - else: - operations_list = [ - OperationKind.Gemm - , OperationKind.Conv2d - , OperationKind.Conv3d - , OperationKind.RankK - , OperationKind.Trmm - , OperationKind.Symm - ] - self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] - - if args.kernels == 'all': - self.kernel_names = [] - else: - self.kernel_names = [x for x in args.kernels.split(',') if x != ''] - - self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] - self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] - - if args.kernel_filter_file is None: - self.kernel_filter_list = [] - else: - self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) - _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( - filter_count = len(self.kernel_filter_list), - filter_file = args.kernel_filter_file)) - - self.operation_count = 0 - self.operations_by_name = {} - self.disable_full_archs_compilation = args.disable_full_archs_compilation - self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != '' - self.instantiation_level = 0 - try: - self.instantiation_level = int(args.instantiation_level) - except ValueError: - self.instantiation_level = 0 - - def add_kernel_filter(self, filter_str): - filter_re = re.compile(filter_str) - - self.kernel_filter_list.append(filter_re) - - def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): - # Non-negative integer which determines how many kernels are instantiated. - # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. - # increasing first digit reduces schedule / mixed type pruning, - # increasing second digit generates more cluster sizes, - # increasing third digit generates more MMA multipliers, - # increasing fourth digit generates more instruction shapes. - - if self.instantiation_level > 0: - return self.instantiation_level - - elif self.is_kernel_filter_set_to_all: - return exhaustive_level - - elif self.kernel_filter == '': - return pruned_level - - else: - return default_level - - - def get_kernel_filters(self, kernelListFile): - if os.path.isfile(kernelListFile): - with open(kernelListFile, 'r') as fileReader: - lines = [line.rstrip() for line in fileReader if not line.startswith("#")] - - lines = [re.compile(line) for line in lines if line] - return lines - else: - return [] - - # - def filter_out_kernels(self, kernel_name, kernel_filter_list): - - for kernel_filter_re in kernel_filter_list: - if kernel_filter_re.search(kernel_name) is not None: - return True - - return False - - - # - def _filter_string_matches(self, filter_string, haystack): - ''' Returns true if all substrings appear in the haystack in order''' - substrings = filter_string.split('*') - for sub in substrings: - idx = haystack.find(sub) - if idx < 0: - return False - haystack = haystack[idx + len(sub):] - return True - - # - def filter(self, operation): - ''' Filtering operations based on various criteria''' - - # filter based on compute capability - enabled = not (self.filter_by_cc) - - for cc in self.compute_capabilities_baseline: - - if cc >= operation.tile_description.minimum_compute_capability and \ - cc <= operation.tile_description.maximum_compute_capability and \ - (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)): - - enabled = True - break - - if not enabled: - return False - - if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: - return False - - name = operation.procedural_name() - - # eliminate duplicates - if name in self.operations_by_name.keys(): - return False - - # Filter based on list of valid substrings - if len(self.kernel_names): - enabled = False - - # compare against the include list - for name_substr in self.kernel_names: - if self._filter_string_matches(name_substr, name): - _LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.") - enabled = True - break - else: - _LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.") - - # compare against the exclude list - for name_substr in self.ignore_kernel_names: - if self._filter_string_matches(name_substr, name): - _LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.") - enabled = False - break - else: - _LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.") - - if len(self.kernel_filter_list) > 0: - if self.filter_out_kernels(name, self.kernel_filter_list): - _LOGGER.debug(f"Kernel {name} matched via kernel filter file.") - enabled = True - else: - _LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.") - enabled = False - - # CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect - # if CUTLASS_LIBRARY_KERNELS was specified. - # Changing that would break backwards compatibility. - # Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS, - # that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified. - for name_substr in self.exclude_kernel_names: - if self._filter_string_matches(name_substr, name): - _LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.") - enabled = False - break - else: - _LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.") - - # TODO: filter based on compute data type - return enabled - # - - # - def append(self, operation): - ''' - Inserts the operation. - - operation_kind -> configuration_name -> [] - ''' - - if self.filter(operation): - - self.selected_kernels.append(operation.procedural_name()) - - self.operations_by_name[operation.procedural_name()] = operation - - # add the configuration - configuration_name = operation.configuration_name() - - # Split operations by minimum CC - min_cc = operation.arch - - if operation.operation_kind not in self.operations.keys(): - self.operations[operation.operation_kind] = {} - - if min_cc not in self.operations[operation.operation_kind]: - self.operations[operation.operation_kind][min_cc] = {} - - if configuration_name not in self.operations[operation.operation_kind][min_cc].keys(): - self.operations[operation.operation_kind][min_cc][configuration_name] = [] - - self.operations[operation.operation_kind][min_cc][configuration_name].append(operation) - self.operation_count += 1 - else: - _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name())) - # - - def emit_manifest_cmake(self, manifest_path, top_level_path, source_files): - with open(manifest_path, "w") as manifest_file: - - target_text = SubstituteTemplate("""cutlass_target_sources(cutlass_library_objs PRIVATE - """, { }) - manifest_file.write(target_text + '\n\n') - manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/'))) - generated_path = os.path.join(self.curr_build_dir, 'generated') - for kind in self.operations.keys(): - kind_str = OperationKindNames[kind] - all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu").replace('\\', '/') - manifest_file.write(f" {all_kind_file}\n") - manifest_file.write(')\n\n') - - for kind in self.operations.keys(): - for min_cc in sorted(self.operations[kind].keys()): - for subclass in sorted(source_files[kind][min_cc].keys()): - target_text = SubstituteTemplate("""cutlass_add_cutlass_library( - SUFFIX ${kind}_sm${min_cc}_${subclass} -""", { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass }) - manifest_file.write(target_text + '\n\n') - - for source_file in source_files[kind][min_cc][subclass]: - manifest_file.write(" %s\n" % str(source_file.replace('\\', '/'))) - - manifest_file.write(")\n") - - if self.disable_full_archs_compilation: - self.emit_disable_full_archs_compilation(manifest_file, source_files) - - def emit_disable_full_archs_compilation(manifest_file, source_files): - def for_hopper(name): - pass - - def for_ampere(name): - return "16816" in name or \ - "16832" in name or \ - "16864" in name or \ - ("1688" in name and "tf32" in name) - - def for_turing(name): - return ("1688" in name and "tf32" not in name) or \ - "8816" in name - - def for_volta(name): - return "884" in name - - def is_cpp(name): - return name.endswith(".cpp") - - def get_src_archs_str_given_requested_cuda_archs(archs, source_file): - intersected_archs = archs & set(self.compute_capabilities_baseline) - if intersected_archs == set(): - raise RuntimeError( - """ - Empty archs set for file {} after taking - the intersection of {} (global requested archs) and - {} (per file requested archs) - """.format(source_file, set(self.compute_capabilities_baseline), archs)) - else: - return " ".join(map(str, intersected_archs)) - - for min_cc in sorted(source_files.keys()): - for source_file in source_files[min_cc]: - if is_cpp(source_file): - continue # skip because source is cpp - elif for_ampere(source_file): - archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file) - elif for_turing(source_file): - archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file) - elif for_volta(source_file): - archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file) - else: - raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file)) - - manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str)) - - # - def emit(self, target = GeneratorTarget.Library): - - operation_emitters = { - GeneratorTarget.Library: EmitOperationKindLibrary - } - - # Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d) - kind_emitters = { - GeneratorTarget.Library: EmitOperationKindAll - } - - interface_emitters = { - GeneratorTarget.Library: EmitInterfaceLibrary - } - - generated_path = os.path.join(self.curr_build_dir, 'generated') - - # create generated/ - if os.path.exists(generated_path): - shutil.rmtree(generated_path) - - os.mkdir(generated_path) - - with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter: - top_level_path = iface_emitter.top_level_path - for operation_kind in self.operations.keys(): - iface_emitter.emit(OperationKindNames[operation_kind]) - - source_files = {} - for kind in self.operations.keys(): - source_files[kind] = {} - for min_cc in self.operations[kind].keys(): - source_files[kind][min_cc] = {} - - for operation_kind, ops in self.operations.items(): - for min_cc, configurations in sorted(ops.items()): - with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter: - for configuration_name, operations in configurations.items(): - _LOGGER.info(f"Emitting {configuration_name} with {len(operations)} operation{'' if len(operations) == 1 else 's'}.") - operation_kind_emitter.emit(configuration_name, operations) - - for subclass, files in operation_kind_emitter.source_files.items(): - if subclass not in source_files[operation_kind][min_cc]: - source_files[operation_kind][min_cc][subclass] = [] - source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass]) - - # Emit top level all_{gemm, conv2d, ...}_operations.cu files - with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: - operation_kind_emitter.emit(ops) - - # write the manifest.cmake file containing paths from all targets - manifest_path = os.path.join(generated_path, "manifest.cmake") - - self.emit_manifest_cmake(manifest_path, top_level_path, source_files) - -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py deleted file mode 100644 index 29ef056f26f914a9c3c33e13900c33642ad2f1b7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py +++ /dev/null @@ -1,438 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting Rank2K kernels -""" - -import enum -import functools -import operator -import os.path -import shutil - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - - -################################################################################################### -# -# Data structure modeling a Rank K update operation -# -################################################################################################### - -# -class Rank2KOperation: - # - def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ - blas_mode = BlasMode.symmetric): - - self.blas_mode = blas_mode - self.operation_kind = OperationKind.Rank2K - self.arch = arch - self.tile_description = tile_description - self.rank_k_kind = rank_k_kind - # tensor A and B have same data type and layout - self.A = A - self.B = A - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 - ] - return self.tile_description.math_instruction.math_operation in complex_operators - return False - - # - def is_mixed_input(self): - return self.A.element != self.B.element - - # - def is_planar_complex(self): - return False - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - MathOperation.and_popc: 'and' - } - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - inst_shape += math_op_string - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k' - - return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - if self.is_complex(): - extended_name = "${core_name}" - else: - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] - ) - return "%s" % (ShortLayoutTypeNames[self.A.layout]) - - # - def fill_mode_name(self): - return "%s" % (ShortFillModeNames[self.C.fill_mode]) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = self.tile_description.procedural_name() - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = max([self.A.alignment, self.C.alignment]) - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'fill_mode': self.fill_mode_name(), - 'alignment': "%d" % self.A.alignment, - } - ) - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -# -class EmitRank2KUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self): - self.rank_k_template = """ -// Rank K operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::Rank2K< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, ${fill_mode}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - ${split_k_serial}, - ${math_operation} ->; -""" - self.rank_k_complex_template = """ -// Rank K operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::Rank2K< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, ${fill_mode}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - ${split_k_serial}, - ${math_operation}, - ${transform_a}, - ${transform_b}, - ${blas_mode} ->; -""" - - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - - warp_count = operation.tile_description.warp_count - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'fill_mode': FillModeTag[operation.C.fill_mode], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'split_k_serial': 'false', - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'blas_mode': BlasModeTag[operation.blas_mode] - } - - rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template - - return SubstituteTemplate(rank_k_template, values) - -################################################################################################### - - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitRank2KConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') - - self.instance_emitter = { - RankKKind.Universal: EmitRank2KUniversalInstance, - } - - self.rank_k_kind_wrappers = { - RankKKind.Universal: 'Rank2KOperation', - } - - self.instance_template = { - RankKKind.Universal: """ -${compile_guard_start} - manifest.append(new ${rank_k_kind}< - Operation_${operation_name} - >("${operation_name}")); -${compile_guard_end} -""" - } - - self.header_template = """ -/* - Generated by rank_2k_operation.py - Do not edit. -*/ - -/////////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "library_internal.h" -#include "rank_2k_operation.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.initialize_function_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_${configuration_name}(Manifest &manifest) { - -""" - self.epilogue_template = """ - -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def __enter__(self): - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(self.header_template) - - self.instance_definitions = [] - self.instance_wrappers = [] - - self.operations = [] - return self - - def emit(self, operation): - emitter = self.instance_emitter[operation.rank_k_kind]() - - self.operations.append(operation) - - self.instance_definitions.append(emitter.emit(operation)) - - self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind], - 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", - 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" - })) - - def __exit__(self, exception_type, exception_value, traceback): - - # Write instance definitions in top-level namespace - for instance_definition in self.instance_definitions: - self.configuration_file.write(instance_definition) - - # Add wrapper objects within initialize() function - self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { - 'configuration_name': self.configuration_name - })) - - for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) - - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py deleted file mode 100644 index 9841952332a170d6f401dbe34a0093540c166fb8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py +++ /dev/null @@ -1,427 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting RankK kernels -""" - -import enum -import functools -import operator -import os.path -import shutil - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - - -################################################################################################### -# -# Data structure modeling a Rank K update operation -# -################################################################################################### - -# -class RankKOperation: - # - def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ - blas_mode = BlasMode.symmetric): - - self.blas_mode = blas_mode - self.operation_kind = OperationKind.RankK - self.arch = arch - self.tile_description = tile_description - self.rank_k_kind = rank_k_kind - self.A = A - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 - ] - return self.tile_description.math_instruction.math_operation in complex_operators - return False - - # - def is_mixed_input(self): - return False - - # - def is_planar_complex(self): - return False - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - MathOperation.and_popc: 'and' - } - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - inst_shape += math_op_string - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk' - - return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - if self.is_complex(): - extended_name = "${core_name}" - else: - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] - ) - return "%s" % (ShortLayoutTypeNames[self.A.layout]) - - # - def fill_mode_name(self): - return "%s" % (ShortFillModeNames[self.C.fill_mode]) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = self.tile_description.procedural_name() - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = max([self.A.alignment, self.C.alignment]) - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'fill_mode': self.fill_mode_name(), - 'alignment': "%d" % self.A.alignment, - } - ) - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -# -class EmitRankKUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self): - self.rank_k_template = """ -// Rank K operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::RankK< - ${element_a}, ${layout_a}, - ${element_c}, ${layout_c}, ${fill_mode}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${split_k_serial}, - ${math_operation} ->; -""" - self.rank_k_complex_template = """ -// Rank K operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::RankK< - ${element_a}, ${layout_a}, - ${element_c}, ${layout_c}, ${fill_mode}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${split_k_serial}, - ${math_operation}, - ${transform_a}, - ${blas_mode} ->; -""" - - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - - warp_count = operation.tile_description.warp_count - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'fill_mode': FillModeTag[operation.C.fill_mode], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'split_k_serial': 'false', - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'blas_mode': BlasModeTag[operation.blas_mode] - } - - rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template - - return SubstituteTemplate(rank_k_template, values) - -################################################################################################### - - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitRankKConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') - - self.instance_emitter = { - RankKKind.Universal: EmitRankKUniversalInstance, - } - - self.rank_k_kind_wrappers = { - RankKKind.Universal: 'RankKOperation', - } - - self.instance_template = { - RankKKind.Universal: """ -${compile_guard_start} - manifest.append(new ${rank_k_kind}< - Operation_${operation_name} - >("${operation_name}")); -${compile_guard_end} -""" - } - - self.header_template = """ -/* - Generated by rank_k_operation.py - Do not edit. -*/ - -/////////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "library_internal.h" -#include "rank_k_operation.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.initialize_function_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_${configuration_name}(Manifest &manifest) { - -""" - self.epilogue_template = """ - -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def __enter__(self): - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(self.header_template) - - self.instance_definitions = [] - self.instance_wrappers = [] - - self.operations = [] - return self - - def emit(self, operation): - emitter = self.instance_emitter[operation.rank_k_kind]() - - self.operations.append(operation) - - self.instance_definitions.append(emitter.emit(operation)) - - self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind], - 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", - 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" - })) - - def __exit__(self, exception_type, exception_value, traceback): - - # Write instance definitions in top-level namespace - for instance_definition in self.instance_definitions: - self.configuration_file.write(instance_definition) - - # Add wrapper objects within initialize() function - self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { - 'configuration_name': self.configuration_name - })) - - for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) - - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py deleted file mode 100644 index 32e4376513679f06dc085ead068e258b3d8b5e72..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py +++ /dev/null @@ -1,342 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Valid tcgen05 shapes and cluster sizes for SM100, associated with levels. -These shape and level pairs are defined as dicts, where keys are shapes and values are their -associated levels. If the user input level for that category (tcgen05 shape, cluster -size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. -Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. -Level 0 is always emitted. -""" - -try: - from .library import DynamicClusterShape -except: - from library import DynamicClusterShape - -SM100_CLUSTER_SHAPES_1SM = { - tuple(DynamicClusterShape) : 0, - # size 1 cluster - (1, 1, 1): 1, - # size 2 cluster - (1, 2, 1): 2, - (2, 1, 1): 5, - # size 4 clusters - (2, 2, 1): 6, - (1, 4, 1): 3, - (4, 1, 1): 6, - # size 8 clusters - (2, 4, 1): 7, - (4, 2, 1): 7, - (1, 8, 1): 8, - (8, 1, 1): 8, - # size 16 cluster - (4, 4, 1): 4, -} - -SM100_CLUSTER_SHAPES_2SM = { - tuple(DynamicClusterShape) : 0, - # size 2 cluster - (2, 1, 1): 1, - # size 4 clusters - (2, 2, 1): 2, - (4, 1, 1): 2, - # size 8 clusters - (2, 4, 1): 3, - (4, 2, 1): 3, - (8, 1, 1): 6, - # size 16 cluster - (4, 4, 1): 4, -} - -# MMA shapes - -# 16b Dense - -SM100_MMA_SHAPES_16b_DENSE_1SM = { - (64, 8, 16): 5, - (64, 16, 16): 2, - (64, 24, 16): 5, - (64, 32, 16): 2, - (64, 40, 16): 5, - (64, 48, 16): 5, - (64, 56, 16): 5, - (64, 64, 16): 2, - (64, 72, 16): 5, - (64, 80, 16): 5, - (64, 88, 16): 5, - (64, 96, 16): 5, - (64, 104, 16): 5, - (64, 112, 16): 5, - (64, 120, 16): 5, - (64, 128, 16): 0, - (64, 136, 16): 5, - (64, 144, 16): 5, - (64, 152, 16): 5, - (64, 160, 16): 5, - (64, 168, 16): 5, - (64, 176, 16): 5, - (64, 184, 16): 5, - (64, 192, 16): 3, - (64, 200, 16): 5, - (64, 208, 16): 5, - (64, 216, 16): 5, - (64, 224, 16): 5, - (64, 232, 16): 5, - (64, 240, 16): 5, - (64, 248, 16): 5, - (64, 256, 16): 3, - - (128, 16, 16): 2, - (128, 32, 16): 2, - (128, 48, 16): 5, - (128, 64, 16): 2, - (128, 80, 16): 5, - (128, 96, 16): 5, - (128, 112, 16): 5, - (128, 128, 16): 0, - (128, 144, 16): 5, - (128, 160, 16): 5, - (128, 176, 16): 5, - (128, 192, 16): 3, - (128, 208, 16): 5, - (128, 224, 16): 5, - (128, 240, 16): 5, - (128, 256, 16): 0, - -} - - -SM100_MMA_SHAPES_16b_DENSE_2SM = { - (128, 32, 16): 2, - (128, 64, 16): 2, - (128, 96, 16): 5, - (128, 128, 16): 0, - (128, 160, 16): 5, - (128, 192, 16): 5, - (128, 224, 16): 5, - (128, 256, 16): 0, - - (256, 32, 16): 2, - (256, 64, 16): 2, - (256, 96, 16): 5, - (256, 128, 16): 0, - (256, 160, 16): 5, - (256, 192, 16): 3, - (256, 224, 16): 5, - (256, 256, 16): 0, -} - -# TF32 Dense - -SM100_MMA_SHAPES_TF32_DENSE_1SM = { - (64, 8, 8): 5, - (64, 16, 8): 2, - (64, 24, 8): 5, - (64, 32, 8): 2, - (64, 40, 8): 5, - (64, 48, 8): 5, - (64, 56, 8): 5, - (64, 64, 8): 1, - (64, 72, 8): 5, - (64, 80, 8): 5, - (64, 88, 8): 5, - (64, 96, 8): 5, - (64, 104, 8): 5, - (64, 112, 8): 5, - (64, 120, 8): 5, - (64, 128, 8): 0, - (64, 136, 8): 5, - (64, 144, 8): 5, - (64, 152, 8): 5, - (64, 160, 8): 5, - (64, 168, 8): 5, - (64, 176, 8): 5, - (64, 184, 8): 5, - (64, 192, 8): 3, - (64, 200, 8): 5, - (64, 208, 8): 5, - (64, 216, 8): 5, - (64, 224, 8): 5, - (64, 232, 8): 5, - (64, 240, 8): 5, - (64, 248, 8): 5, - (64, 256, 8): 3, - - (128, 16, 8): 2, - (128, 32, 8): 2, - (128, 48, 8): 5, - (128, 64, 8): 2, - (128, 80, 8): 5, - (128, 96, 8): 5, - (128, 112, 8): 5, - (128, 128, 8): 0, - (128, 144, 8): 5, - (128, 160, 8): 5, - (128, 176, 8): 5, - (128, 192, 8): 3, - (128, 208, 8): 5, - (128, 224, 8): 5, - (128, 240, 8): 5, - (128, 256, 8): 0, - -} - -SM100_MMA_SHAPES_TF32_DENSE_2SM = { - (128, 32, 8): 2, - (128, 64, 8): 1, - (128, 96, 8): 5, - (128, 128, 8): 0, - (128, 160, 8): 5, - (128, 192, 8): 5, - (128, 224, 8): 5, - (128, 256, 8): 0, - - (256, 32, 8): 2, - (256, 64, 8): 1, - (256, 96, 8): 5, - (256, 128, 8): 0, - (256, 160, 8): 5, - (256, 192, 8): 5, - (256, 224, 8): 5, - (256, 256, 8): 0, -} - -# F8F6F4 -SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = { - (64, 8, 32): 4, - (64, 16, 32): 4, - (64, 24, 32): 5, - (64, 32, 32): 3, - (64, 40, 32): 5, - (64, 48, 32): 5, - (64, 56, 32): 5, - (64, 64, 32): 2, - (64, 72, 32): 5, - (64, 80, 32): 5, - (64, 88, 32): 5, - (64, 96, 32): 5, - (64, 104, 32): 5, - (64, 112, 32): 5, - (64, 120, 32): 5, - (64, 128, 32): 0, - (64, 136, 32): 5, - (64, 144, 32): 5, - (64, 152, 32): 5, - (64, 160, 32): 5, - (64, 168, 32): 5, - (64, 176, 32): 5, - (64, 184, 32): 5, - (64, 192, 32): 5, - (64, 200, 32): 5, - (64, 208, 32): 5, - (64, 216, 32): 5, - (64, 224, 32): 5, - (64, 232, 32): 5, - (64, 240, 32): 5, - (64, 248, 32): 5, - (64, 256, 32): 0, - - (128, 16, 32): 4, - (128, 32, 32): 3, - (128, 48, 32): 5, - (128, 64, 32): 2, - (128, 80, 32): 5, - (128, 96, 32): 5, - (128, 112, 32): 5, - (128, 128, 32): 0, - (128, 144, 32): 5, - (128, 160, 32): 5, - (128, 176, 32): 5, - (128, 192, 32): 5, - (128, 208, 32): 5, - (128, 224, 32): 5, - (128, 240, 32): 5, - (128, 256, 32): 0, - -} - -SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = { - (128, 32, 32): 3, - (128, 64, 32): 2, - (128, 96, 32): 5, - (128, 128, 32): 1, - (128, 160, 32): 5, - (128, 192, 32): 5, - (128, 224, 32): 5, - (128, 256, 32): 1, - - (256, 32, 32): 2, - (256, 64, 32): 2, - (256, 96, 32): 5, - (256, 128, 32): 0, - (256, 160, 32): 5, - (256, 192, 32): 5, - (256, 224, 32): 5, - (256, 256, 32): 0, -} - -# MXF8F6F4 -SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { - (128, 64, 32): 1, - (128, 128, 32): 0, - (128, 192, 32): 1, - (128, 256, 32): 0, -} - - -SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { - (256, 64, 32): 1, - (256, 128, 32): 0, - (256, 192, 32): 1, - (256, 256, 32): 0, - - -} - -# MXF4NVF4 -SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { - (128, 64, 64): 1, - (128, 128, 64): 0, - (128, 192, 64): 1, - (128, 256, 64): 0, -} - -SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { - # Multiples of 16 for N - (256, 64, 64): 1, - (256, 128, 64): 0, - (256, 192, 64): 1, - (256, 256, 64): 0, - -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py deleted file mode 100644 index 9bf24fe7f528020be4dcfc6ac41cfe949dd63be5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py +++ /dev/null @@ -1,661 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for enumerating CUTLASS library SM100 kernels -""" - -import argparse -import enum -from itertools import product -import math -import logging -import os.path -import shutil -import sys -import copy -from typing import Any, Optional, Sequence, Tuple, List, Union, Callable - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - -#### Step 0: define levels - -# One integer level controls multiple "generators" and how many -# combinations they generate. That is the "global" level. -# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and -# anything that is eventually involved in the Cartesian product -# which yields our kernel configurations. -# For simplicity, each generator defines their own levels, -# starting from 0. As a rule we assume 10 or fewer levels, making -# their level a digit. -# The "global" level simply stacks these digits and represents them -# as a single integer. -# -# For example, level 500 indicates cluster sizes are at level 5, MMA -# multipliers are at level 0, and WGMMA shapes are at level 0 as well. -# -# Here we define the global level to generator level mappings. - - -def get_tcgen05_level_from_global_level(global_level: int): - return global_level % 10 - -def get_mma_level_from_global_level(global_level: int): - return (global_level // 10) % 10 - - -def get_cluster_level_from_global_level(global_level: int): - return (global_level // 100) % 10 - - -def get_pruning_level_from_global_level(global_level: int): - return (global_level // 1000) % 10 - - -#### Step 1: generate MMA instruction shapes based on levels - -try: - from .sm100_shapes import * -except: - from sm100_shapes import * - -########### - -def generate_tf32_math_instructions_sm100(level: int): - """ - Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level. - - Args: - level: The global level to generate math instructions for. - - Returns: - A tuple of two lists of MathInstruction objects. - The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. - """ - tcgen05_level = get_tcgen05_level_from_global_level(level) - math_instructions_1sm = [] - math_instructions_2sm = [] - - shapes_1sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level - ] - shapes_2sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level - ] - - for shape in shapes_1sm: - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - for shape in shapes_2sm: - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - return math_instructions_1sm, math_instructions_2sm - -def generate_16b_math_instructions_sm100(level: int): - """ - Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level. - - Args: - level: The global level to generate math instructions for. - - Returns: - A tuple of two lists of MathInstruction objects. - The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. - """ - tcgen05_level = get_tcgen05_level_from_global_level(level) - math_instructions_1sm = [] - math_instructions_2sm = [] - - shapes_1sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level - ] - shapes_2sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level - ] - - for shape in shapes_1sm: - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - - for shape in shapes_2sm: - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - return math_instructions_1sm, math_instructions_2sm - - -def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): - """ - Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level. - - Args: - level: The global level to generate math instructions for. - enable_runtime_dtype: Whether to generate runtime dtype math instructions. - enable_compile_time_dtype: Whether to generate compile time dtype math instructions. - - Returns: - A tuple of two lists of MathInstruction objects. - The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. - """ - - tcgen05_level = get_tcgen05_level_from_global_level(level) - pruning_level = get_pruning_level_from_global_level(level) - math_instructions_1sm = [] - math_instructions_2sm = [] - - shapes_1sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level - ] - shapes_2sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level - ] - - for shape in shapes_1sm: - if enable_runtime_dtype: - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - if enable_compile_time_dtype: - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - if pruning_level >= 2: - math_instructions_1sm.append( - MathInstruction( - shape, - DataType.e5m2, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - for shape in shapes_2sm: - if enable_runtime_dtype: - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - if enable_compile_time_dtype: - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - if pruning_level >= 2: - math_instructions_2sm.append( - MathInstruction( - shape, - DataType.e5m2, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - return math_instructions_1sm, math_instructions_2sm - -def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): - """ - Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level. - - Args: - level: The global level to generate math instructions for. - enable_runtime_dtype: Whether to generate runtime dtype math instructions. - enable_compile_time_dtype: Whether to generate compile time dtype math instructions. - - Returns: - A tuple of two lists of MathInstruction objects. - The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. - """ - - tcgen05_level = get_tcgen05_level_from_global_level(level) - math_instructions_1sm = [] - math_instructions_2sm = [] - - shapes_1sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level - ] - shapes_2sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level - ] - - for shape in shapes_1sm: - if enable_runtime_dtype: - - runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] - - for a_type, b_type in product(runtime_types, repeat=2): - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - if enable_compile_time_dtype: - compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] - - for a_type, b_type in product(compile_time_types, repeat=2): - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - - for shape in shapes_2sm: - if enable_runtime_dtype: - - runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] - - for a_type, b_type in product(runtime_types, repeat=2): - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - if enable_compile_time_dtype: - compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] - - for a_type, b_type in product(compile_time_types, repeat=2): - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - return math_instructions_1sm, math_instructions_2sm - -def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): - """ - Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level. - - Args: - level: The global level to generate math instructions for. - enable_runtime_dtype: Whether to generate runtime dtype math instructions. - enable_compile_time_dtype: Whether to generate compile time dtype math instructions. - - Returns: - A tuple of two lists of MathInstruction objects. - The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. - """ - - tcgen05_level = get_tcgen05_level_from_global_level(level) - pruning_level = get_pruning_level_from_global_level(level) - - math_instructions_1sm = [] - math_instructions_2sm = [] - - shapes_1sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level - ] - shapes_2sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level - ] - - for shape in shapes_1sm: - if enable_runtime_dtype: - - runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] - - for a_type, b_type in product(runtime_types, repeat=2): - - if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): - continue - - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - if enable_compile_time_dtype: - compile_time_types = [ DataType.e4m3, - DataType.e5m2, - DataType.e3m2, - DataType.e2m3, - DataType.e2m1 ] - - for a_type, b_type in product(compile_time_types, repeat=2): - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - - for shape in shapes_2sm: - if enable_runtime_dtype: - - runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] - - for a_type, b_type in product(runtime_types, repeat=2): - - if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): - continue - - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - if enable_compile_time_dtype: - compile_time_types = [ DataType.e4m3, - DataType.e5m2, - DataType.e3m2, - DataType.e2m3, - DataType.e2m1 ] - - for a_type, b_type in product(compile_time_types, repeat=2): - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - return math_instructions_1sm, math_instructions_2sm - -def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): - """ - Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level. - - Args: - level: The global level to generate math instructions for. - enable_runtime_dtype: Whether to generate runtime dtype math instructions. - enable_compile_time_dtype: Whether to generate compile time dtype math instructions. - - Returns: - A tuple of two lists of MathInstruction objects. - The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. - """ - tcgen05_level = get_tcgen05_level_from_global_level(level) - math_instructions_1sm = [] - math_instructions_2sm = [] - - shapes_1sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level - ] - shapes_2sm = [ - shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level - ] - - for shape in shapes_1sm: - if enable_runtime_dtype: - - runtime_types = [ DataType.f4 ] - - for a_type, b_type in product(runtime_types, repeat=2): - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) - ) - - - if enable_compile_time_dtype: - compile_time_types = [ DataType.e2m1, - ] - - for a_type, b_type in product(compile_time_types, repeat=2): - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - math_instructions_1sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) - ) - - - for shape in shapes_2sm: - if enable_runtime_dtype: - - runtime_types = [ DataType.f4 ] - - for a_type, b_type in product(runtime_types, repeat=2): - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) - ) - - - if enable_compile_time_dtype: - compile_time_types = [ DataType.e2m1, - ] - - for a_type, b_type in product(compile_time_types, repeat=2): - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - math_instructions_2sm.append( - MathInstruction( - shape, - a_type, b_type, DataType.f32, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) - ) - - - return math_instructions_1sm, math_instructions_2sm - - -def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None): - """ - Generate all cluster shapes for SM100 at or above the given level. - - Args: - level: The global level to generate cluster shapes for. - - Returns: - A tuple of two lists of cluster shapes. - The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM. - """ - cluster_level = get_cluster_level_from_global_level(level) - - assert cluster_level >= 4 - - if change_priority_func is not None: - SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM) - SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM) - change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY) - shapes_1sm = [ - list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level - ] - shapes_2sm = [ - list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level - ] - - return shapes_1sm, shapes_2sm - - else: - - shapes_1sm = [ - list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level - ] - shapes_2sm = [ - list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level - ] - - return shapes_1sm, shapes_2sm diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py deleted file mode 100644 index e14761aae6494f877e6dc6521b30baea0db7509c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py +++ /dev/null @@ -1,212 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels. -These shape and level pairs are defined as dicts, where keys are shapes and values are their -associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster -size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. -Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. -Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted -when the `--kernel` argument is non-empty. -""" - -# NOTE: more combinations are possible here. -# Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes. -# The rest are only used in the exhaustive mode (when the corresponding level digit is 9). -# MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes. -SM90_MMA_MULTIPLIERS = { - (2, 1, 4): 0, - (1, 1, 4): 1, - (4, 1, 4): 2, - (2, 2, 4): 3, - (2, 1, 8): 4, - (4, 1, 8): 4, - (1, 1, 8): 4, - (2, 2, 8): 4, - (2, 1, 16): 5, - (4, 1, 16): 5, - (1, 1, 16): 5, - (2, 2, 16): 5, -} - -# Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case -# Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case -# Level 2: clusters with 1 or 2 CTAs -# Level 3: clusters with 1, 2, or 4 CTAs -# Level 4: clusters with 1, 2, 4, or 8 CTAs -# Level 5: clusters with 1, 2, 4, 8, or 16 CTAs -SM90_CLUSTER_SIZES = { - (1, 2, 1): 0, - (2, 1, 1): 1, - (1, 1, 1): 2, - (2, 2, 1): 3, - (1, 4, 1): 3, - (4, 1, 1): 3, - (2, 4, 1): 4, - (4, 2, 1): 4, - (1, 8, 1): 4, - (8, 1, 1): 4, - (4, 4, 1): 5, -} - - -# WGMMA shapes -# Level 0: "default" shape only, -# Level 1: additional shapes for the unpruned case (tf32 only) -# Level 2: shapes that are all powers of 2 -# Level 3: all other shapes -SM90_WGMMA_SHAPES_FP16_BF16_DENSE = { - (64, 8, 16): 2, - (64, 16, 16): 2, - (64, 24, 16): 3, - (64, 32, 16): 2, - (64, 40, 16): 3, - (64, 48, 16): 3, - (64, 56, 16): 3, - (64, 64, 16): 2, - (64, 72, 16): 3, - (64, 80, 16): 3, - (64, 88, 16): 3, - (64, 96, 16): 3, - (64, 104, 16): 3, - (64, 112, 16): 3, - (64, 120, 16): 3, - (64, 128, 16): 0, - (64, 136, 16): 3, - (64, 144, 16): 3, - (64, 152, 16): 3, - (64, 160, 16): 3, - (64, 168, 16): 3, - (64, 176, 16): 3, - (64, 184, 16): 3, - (64, 192, 16): 3, - (64, 200, 16): 3, - (64, 208, 16): 3, - (64, 216, 16): 3, - (64, 224, 16): 3, - (64, 232, 16): 3, - (64, 240, 16): 3, - (64, 248, 16): 3, - (64, 256, 16): 1, -} - -SM90_WGMMA_SHAPES_TF32_DENSE = { - (64, 8, 8): 2, - (64, 16, 8): 2, - (64, 24, 8): 3, - (64, 32, 8): 2, - (64, 40, 8): 3, - (64, 48, 8): 3, - (64, 56, 8): 3, - (64, 64, 8): 2, - (64, 72, 8): 3, - (64, 80, 8): 3, - (64, 88, 8): 3, - (64, 96, 8): 3, - (64, 104, 8): 3, - (64, 112, 8): 3, - (64, 120, 8): 3, - (64, 128, 8): 0, - (64, 136, 8): 3, - (64, 144, 8): 3, - (64, 152, 8): 3, - (64, 160, 8): 3, - (64, 168, 8): 3, - (64, 176, 8): 3, - (64, 184, 8): 3, - (64, 192, 8): 3, - (64, 200, 8): 3, - (64, 208, 8): 3, - (64, 216, 8): 3, - (64, 224, 8): 3, - (64, 232, 8): 3, - (64, 240, 8): 3, - (64, 248, 8): 3, - (64, 256, 8): 1, -} - -SM90_WGMMA_SHAPES_FP8_DENSE = { - (64, 8, 32): 2, - (64, 16, 32): 2, - (64, 24, 32): 3, - (64, 32, 32): 2, - (64, 40, 32): 3, - (64, 48, 32): 3, - (64, 56, 32): 3, - (64, 64, 32): 2, - (64, 72, 32): 3, - (64, 80, 32): 3, - (64, 88, 32): 3, - (64, 96, 32): 3, - (64, 104, 32): 3, - (64, 112, 32): 3, - (64, 120, 32): 3, - (64, 128, 32): 0, - (64, 136, 32): 3, - (64, 144, 32): 3, - (64, 152, 32): 3, - (64, 160, 32): 3, - (64, 168, 32): 3, - (64, 176, 32): 3, - (64, 184, 32): 3, - (64, 192, 32): 3, - (64, 200, 32): 3, - (64, 208, 32): 3, - (64, 216, 32): 3, - (64, 224, 32): 3, - (64, 232, 32): 3, - (64, 240, 32): 3, - (64, 248, 32): 3, - (64, 256, 32): 1, -} - -SM90_WGMMA_SHAPES_INT8_DENSE = { - (64, 8, 32): 2, - (64, 16, 32): 2, - (64, 24, 32): 3, - (64, 32, 32): 2, - (64, 48, 32): 3, - (64, 64, 32): 2, - (64, 80, 32): 3, - (64, 96, 32): 3, - (64, 112, 32): 3, - (64, 128, 32): 0, - (64, 144, 32): 3, - (64, 160, 32): 3, - (64, 176, 32): 3, - (64, 192, 32): 3, - (64, 208, 32): 3, - (64, 224, 32): 3, - (64, 240, 32): 3, - (64, 256, 32): 1, -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py deleted file mode 100644 index fc5fdf14abb85835f71ecfd704a2738f5792af50..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py +++ /dev/null @@ -1,753 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for enumerating CUTLASS library SM90 kernels -""" - -import argparse -import enum -from itertools import product -import math -import logging -import os.path -import shutil -import sys -import copy -from typing import Any, Optional, Sequence, Tuple, List - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - -# NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py -def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): - - # by default, use the latest CUDA Toolkit version - cuda_version = [11, 0, 132] - - # Update cuda_version based on parsed string - if semantic_ver_string != '': - for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): - if i < len(cuda_version): - cuda_version[i] = x - else: - cuda_version.append(x) - return cuda_version >= [major, minor, patch] - -#### Step 0: define levels - -# One integer level controls multiple "generators" and how many -# combinations they generate. That is the "global" level. -# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and -# anything that is eventually involved in the Cartesian product -# which yields our kernel configurations. -# For simplicity, each generator defines their own levels, -# starting from 0. As a rule we assume 10 or fewer levels, making -# their level a digit. -# The "global" level simply stacks these digits and represents them -# as a single integer. -# -# For example, level 500 indicates cluster sizes are at level 5, MMA -# multipliers are at level 0, and WGMMA shapes are at level 0 as well. -# -# Here we define the global level to generator level mappings. - - -def get_wgmma_level_from_global_level(global_level: int): - return global_level % 10 - - -def get_mma_level_from_global_level(global_level: int): - return (global_level // 10) % 10 - - -def get_cluster_level_from_global_level(global_level: int): - return (global_level // 100) % 10 - - -def get_pruning_level_from_global_level(global_level: int): - return (global_level // 1000) % 10 - - -#### Step 1: generate MMA instruction shapes based on levels - -try: - from .sm90_shapes import ( - SM90_MMA_MULTIPLIERS, - SM90_CLUSTER_SIZES, - SM90_WGMMA_SHAPES_TF32_DENSE, - SM90_WGMMA_SHAPES_FP16_BF16_DENSE, - SM90_WGMMA_SHAPES_FP8_DENSE, - SM90_WGMMA_SHAPES_INT8_DENSE, - ) -except: - from sm90_shapes import ( - SM90_MMA_MULTIPLIERS, - SM90_CLUSTER_SIZES, - SM90_WGMMA_SHAPES_TF32_DENSE, - SM90_WGMMA_SHAPES_FP16_BF16_DENSE, - SM90_WGMMA_SHAPES_FP8_DENSE, - SM90_WGMMA_SHAPES_INT8_DENSE, - ) - - -def generate_tf32_math_instruction_shapes_sm90(level: int): - assert isinstance(level, int) and level >= 0 - filtered_list_of_wgmma_shapes = [ - wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level - ] - return filtered_list_of_wgmma_shapes - -def generate_fp16_bf16_math_instruction_shapes_sm90(level: int): - assert isinstance(level, int) and level >= 0 - filtered_list_of_wgmma_shapes = [ - wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level - ] - return filtered_list_of_wgmma_shapes - -def generate_fp8_math_instruction_shapes_sm90(level: int): - assert isinstance(level, int) and level >= 0 - filtered_list_of_wgmma_shapes = [ - wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level - ] - return filtered_list_of_wgmma_shapes - -def generate_int8_math_instruction_shapes_sm90(level: int): - assert isinstance(level, int) and level >= 0 - filtered_list_of_wgmma_shapes = [ - wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level - ] - return filtered_list_of_wgmma_shapes - -def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType): - # DataTypeSize are in the unit of bits - a_bytes = DataTypeSize[a_type] // 8 - b_bytes = DataTypeSize[b_type] // 8 - if a_bytes == 4 or b_bytes == 4: - return generate_tf32_math_instruction_shapes_sm90(wgmma_level) - elif a_bytes == 2 or b_bytes == 2: - return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level) - else: - return generate_fp8_math_instruction_shapes_sm90(wgmma_level) - -########### - -def generate_tf32_math_instructions_sm90(level: int): - wgmma_level = get_wgmma_level_from_global_level(level) - math_instructions = [] - for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level): - math_instructions.append( - MathInstruction( - math_instruction_shape, - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - return math_instructions - -def generate_fp16_bf16_math_instructions_sm90(level: int): - wgmma_level = get_wgmma_level_from_global_level(level) - math_instructions = [] - for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level): - math_instructions += [ - MathInstruction( - math_instruction_shape, - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - math_instruction_shape, - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - math_instruction_shape, - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - return math_instructions - -def generate_fp8_math_instructions_sm90(level: int): - wgmma_level = get_wgmma_level_from_global_level(level) - math_instructions = [] - for math_instruction_shape in generate_fp8_math_instruction_shapes_sm90(wgmma_level): - math_instructions += [ - MathInstruction( - math_instruction_shape, - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - math_instruction_shape, - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - math_instruction_shape, - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - math_instruction_shape, - DataType.e5m2, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - return math_instructions - -def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]): - wgmma_level = get_wgmma_level_from_global_level(level) - math_instructions = [] - for a_type, b_type, acc_type in types_of_a_b_acc: - math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type) - for math_instruction_shape in math_instruction_shapes: - math_instructions += [ - MathInstruction( - math_instruction_shape, - a_type, b_type, acc_type, - OpcodeClass.TensorOp, - MathOperation.multiply_add - ), - ] - return math_instructions - -def generate_int8_math_instructions_sm90(level: int): - wgmma_level = get_wgmma_level_from_global_level(level) - math_instructions = [] - for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level): - math_instructions += [ - MathInstruction( - math_instruction_shape, - DataType.s8, DataType.s8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - math_instruction_shape, - DataType.u8, DataType.u8, DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - return math_instructions - -def make_sparse_math_instructions(math_instructions): - sparse_instructions = [] - for inst in math_instructions: - if inst.opcode_class == OpcodeClass.TensorOp: - sparse_instructions.append(MathInstruction( - (inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2), - inst.element_a, inst.element_b, inst.element_accumulator, - OpcodeClass.SparseTensorOp, - inst.math_operation),) - return sparse_instructions - - -#### Step 2: generate tile descriptions from math instruction shapes - -def is_tile_desc_valid(tile_description): - if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90: - return False - - element_a, element_b, element_accum = ( - tile_description.math_instruction.element_a, - tile_description.math_instruction.element_b, - tile_description.math_instruction.element_accumulator - ) - - cluster_size, cta_shape = ( - tile_description.cluster_shape, - tile_description.threadblock_shape, - ) - grid_size = ( - cta_shape[0] * cluster_size[0] + - cta_shape[1] * cluster_size[1] + - cta_shape[2] * cluster_size[2] - ) - num_ctas_in_cluster = cluster_size[0] * cluster_size[1] * cluster_size[2] - cluster_shape = ( - cluster_size[0] * cta_shape[0], - cluster_size[1] * cta_shape[1], - cluster_size[2] * cta_shape[2] - ) - - FP32_TYPES = [DataType.f32, DataType.tf32] - FP16_TYPES = [DataType.f16, DataType.bf16] - is_fp32 = element_a in FP32_TYPES and element_b in FP32_TYPES - is_fp16 = element_a in FP16_TYPES and element_b in FP16_TYPES - - # Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is - # allowed for non portable clusters. - if num_ctas_in_cluster > 16 or num_ctas_in_cluster < 1: - return False - - if grid_size < 1: - return False - - # SM90 WGMMA shapes are always 64 across M, therefore - # CTA shape across M must always be a multiple of 64. - if cta_shape[0] < 64 or cta_shape[0] % 64 != 0: - return False - - # The minimum WGMMA shape across N is 8, and increments - # vary across different dtypes, but they're never smaller - # than 8. The minimum CTA shape allowed across N though is 16. - if cta_shape[1] < 16 or cta_shape[1] % 8 != 0: - return False - - # SM90 WGMMA shapes across K are always 8 for 32 bit dense - # operations, 16 for 16 bit, and 32 for 8 bit. In any case, - # the CTA shape across K should be a multiple of 8 and at least - # twice the WGMMA shape across K. - if cta_shape[2] < 16 or cta_shape[2] % 8 != 0: - return False - - # Minimum of 2 stages (very rough heuristic that may filter out valid kernel configs) - if (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 256: - return False - - if is_fp32 and (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 128: - return False - - if is_fp32 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 64: - return False - - if is_fp16 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 128: - return False - - # CTA shape upper bound: <256, 256, 256> - if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256: - return False - - return True - -def get_mma_multipliers(level: int): - assert isinstance(level, int) and level >= 0 - mma_level = get_mma_level_from_global_level(level) - return [ - mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level - ] - -def get_cluster_sizes(level: int, is_aligned: bool): - if not is_aligned: - return [(1, 1, 1)] - assert isinstance(level, int) and level >= 0 - cluster_level = get_cluster_level_from_global_level(level) - return [ - cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level - ] - -def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int): - tile_descriptions = set() - mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned) - for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes): - - # generator can stamp out duplicate kernels, because it doesn't explicitly set instruction - # shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using - # the auto kernel schedule. - - math_inst_stub = copy.deepcopy(math_inst) - math_inst_stub.instruction_shape = [0, 0, 0] - - tile_desc = TileDescription( - threadblock_shape=[ - math_inst.instruction_shape[0] * mma_mul[0], - math_inst.instruction_shape[1] * mma_mul[1], - math_inst.instruction_shape[2] * mma_mul[2] - ], - stages=0, - warp_count=[4, 1, 1], - math_instruction=math_inst_stub, - min_compute=90, - max_compute=90, - cluster_shape=cluster_size) - # For sparse kernels K-tile is twice as large (due to 2x MMA-K size) - # Reduce it to same size as dense to afford more smem stages - if math_inst.opcode_class == OpcodeClass.SparseTensorOp: - tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2 - if is_tile_desc_valid(tile_desc): - tile_descriptions.add(tile_desc) - - return tile_descriptions - -#### Step 3: map tile description to valid schedules - -def is_tile_desc_compatible_with_cooperative(tile_description): - # Cooperative kernels require a minimum CTA-M of 128 - return tile_description.threadblock_shape[0] % 128 == 0 - - -def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types): - dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = ( - data_types["a_type"], - data_types["b_type"], - data_types["c_type"], - data_types["d_type"], - data_types["acc_type"], - data_types["epi_type"] - ) - mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1] - bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d] - - shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn - shmem_bits_total = shmem_bits_c + shmem_bits_d - # Magic number: 2^20 - # Existing logic suggested that tile shape 256x128 (or 128x256) - # would run out of shmem if D is FP32, and source is needed. - # That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit. - # Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB. - # Since epilogue can't possibly use ALL of the shmem available - # we can just settle on 2^20 bits (~ 131 KB) being the upper bound - # we would allow for epilogue. - # This can be different for non-persistent kernels where epilogue and - # mainloop shmem is shared. - if shmem_bits_total > 2 ** 20: - return False - - return True - - -def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout, - instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x): - # Level 0: prune according to existing generator.py behavior - # Level >= 1: no pruning - level = get_pruning_level_from_global_level(instantiation_level) - schedules = [] - stream_k_schedules = [] - - if not is_tile_desc_valid(tile_description): - return schedules, stream_k_schedules - - FP16_TYPES = [DataType.f16, DataType.bf16] - is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES - - FP8_TYPES = [DataType.e4m3, DataType.e5m2] - is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES - can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc - - FP32_TYPES = [DataType.f32, DataType.tf32] - is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES - requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor - - can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description) - can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types) - - default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed - auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed - - cta_m, cta_n, cta_k = ( - tile_description.threadblock_shape[0], - tile_description.threadblock_shape[1], - tile_description.threadblock_shape[2] - ) - c_type = data_types["c_type"] - d_type = data_types["d_type"] - is_void_c = c_type == DataType.void - - # Filter out invalid kernels - is_nt = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.RowMajor - is_tn = layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.ColumnMajor - is_nn = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.ColumnMajor - - # static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, - # "Copy size must evenly divide SMEM tile."); - if is_fp32 and is_nt and (cta_n % cta_k != 0): - return [], [] - - # static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits::value))) == 128, - # "SmemLayoutB K must be 128bytes to be transposed.") - if is_fp32 and is_nt and cta_k != 32: - return [], [] - - # Static assert failure when instantiating SmemLayoutB - if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0): - return [], [] - - grouped = is_grouped(gemm_kind) - if grouped: - # the following cases are unsupported by grouped GEMM - if not is_aligned: - return [], [] - if requires_transposed_epilogue: - return [], [] - - # Early pruning - if level < 1: - # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64 - if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64: - return [], [] - - # FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules - is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128 - if is_large_fp8_tile: - # Only void-C, and only FP8 outputs allowed - if not is_void_c or d_type not in FP8_TYPES: - return [], [] - if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue: - schedules = [] - if is_blockwise(gemm_kind): - schedules.append( - [ - to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) - ]) - else: - schedules.append( - [ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) - ]) - schedules.append( - [ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) - ]) - return schedules, [] - return [], [] - - if is_fp8 and not is_large_fp8_tile: - valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void] - # Prune all configs with fp8 source, and all configs with non-fp8 output - # that have different dtypes for source and output. - if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type): - return [], [] - - # FP32/TF32 kernels don't stamp out void-C - if is_fp32 and is_void_c: - return [], [] - - # Void-c only makes a difference for TMA epilogues - if is_void_c and not can_do_tma_epilogue: - return [], [] - - # For mixed input data types - a_type_size = DataTypeSize[data_types["a_type"]] - b_type_size = DataTypeSize[data_types["b_type"]] - if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1): - schedules = [] - stream_k_schedules = [] - epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized - if a_type_size > b_type_size: - epilogue_schedule = EpilogueScheduleType.EpilogueTransposed - - if not is_blockwise(gemm_kind): - schedules.append([ - KernelScheduleType.TmaWarpSpecialized, - epilogue_schedule - ]) - schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule - ]) - if cta_m >= 128: - if a_type_size > b_type_size: - epilogue_schedule = EpilogueScheduleType.EpilogueTransposed - else: - epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative - if is_blockwise(gemm_kind): - schedules.append([ - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, - epilogue_schedule - ]) - else: - schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperative, - epilogue_schedule - ]) - stream_k_schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperative, - epilogue_schedule - ]) - return schedules, stream_k_schedules - - if not is_aligned and not is_blockwise(gemm_kind): - schedules = [[KernelScheduleType.CpAsyncWarpSpecialized, - default_epilogue]] - stream_k_schedules = [] - - if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative: - schedules.append([ - KernelScheduleType.CpAsyncWarpSpecializedCooperative, - default_epilogue - ]) - stream_k_schedules.append([ - KernelScheduleType.CpAsyncWarpSpecializedCooperative, - default_epilogue - ]) - - return schedules, stream_k_schedules - - schedules = [] - # Pruning: emit Void-C and Grouped kernels with persistent kernels only - if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind): - # Pruning: don't stamp out fp8 kernels with auto schedule - if not is_fp8: - schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue]) - schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue]) - stream_k_schedules = [] - - if CudaToolkitVersionSatisfies(cuda_version, 12, 0): - if can_do_tma_epilogue: - assert not requires_transposed_epilogue - # Inconsistency: fp8 pingpong only gets stamped out with fast accum - if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind): - schedules.append([ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) - ]) - if can_do_fp8_fast_accum: - schedules.append([ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) - ]) - - if CudaToolkitVersionSatisfies(cuda_version, 12, 1): - # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue - if not is_fp8 or level >= 1: - if not is_blockwise(gemm_kind): - schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) - else: - schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) - - if can_do_fp8_fast_accum: - if not grouped: - schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue]) - schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)]) - - if can_do_cooperative: - if is_blockwise(gemm_kind): - schedules.append([ - to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), - to_grouped_schedule(default_epilogue, grouped) - ]) - stream_k_schedules.append([ - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, - default_epilogue - ]) - else: - schedules.append([ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), - to_grouped_schedule(default_epilogue, grouped) - ]) - stream_k_schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperative, - default_epilogue - ]) - if can_do_fp8_fast_accum: - schedules.append([ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), - to_grouped_schedule(default_epilogue, grouped) - ]) - stream_k_schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, - default_epilogue - ]) - - # persistent kernels with TMA epilogues - if can_do_tma_epilogue: - assert not requires_transposed_epilogue - if can_do_cooperative: - if is_blockwise(gemm_kind): - schedules.append([ - to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) - ]) - stream_k_schedules.append([ - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, - EpilogueScheduleType.TmaWarpSpecializedCooperative - ]) - else: - schedules.append([ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) - ]) - stream_k_schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperative, - EpilogueScheduleType.TmaWarpSpecializedCooperative - ]) - if can_do_fp8_fast_accum: - schedules.append([ - to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), - to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) - ]) - stream_k_schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecializedCooperative - ]) - # Grouped GEMM do not support Stream-K scheduler - if grouped: - return schedules, [] - return schedules, stream_k_schedules - - -#### Misc: helpers - -def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None): - element_a, element_b = math_instruction.element_a, math_instruction.element_b - element_accumulator = math_instruction.element_accumulator - element_c = element_source or element_accumulator - element_d = element_dest or element_accumulator - element_epilogue = element_epilogue or element_accumulator - data_types = { - "a_type" : element_a, - "b_type" : element_b, - "c_type" : element_c, - "d_type" : element_d, - "acc_type" : element_accumulator, - "epi_type" : element_epilogue - } - return data_types - -def fix_alignments(data_types, layout, alignment_bits = 128): - operand_keys = ["a_type", "b_type", "c_type"] - operands_to_fix = ["c_type"] - new_layout = [] - assert len(layout) == len(operand_keys) - for i, k in enumerate(operand_keys): - assert k in data_types and data_types[k] in DataTypeSize - dtype = data_types[k] - dtype_size_bits = DataTypeSize[dtype] - - layout_type = layout[i][0] - layout_alignment = layout[i][1] - - # Don't modify alignment if dtype's been changed to void - if k in operands_to_fix and dtype_size_bits >= 1: - layout_alignment = alignment_bits // dtype_size_bits - - new_layout.append([layout_type, layout_alignment]) - - return new_layout diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py deleted file mode 100644 index 8661ff798b2e3e0987fdf7e050b6ad2e0f8f3678..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py +++ /dev/null @@ -1,440 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting Symm kernels -""" - -import enum -import functools -import operator -import os.path -import shutil - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - - -################################################################################################### -# -# Data structure modeling a Symm update operation -# -################################################################################################### - -# -class SymmOperation: - # - def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ - blas_mode = BlasMode.symmetric): - - self.blas_mode = blas_mode - self.operation_kind = OperationKind.Symm - self.arch = arch - self.tile_description = tile_description - self.symm_kind = symm_kind - # tensor A and B have same data type and layout - self.A = A - self.B = B - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 - ] - return self.tile_description.math_instruction.math_operation in complex_operators - return False - - # - def is_mixed_input(self): - return self.A.element != self.B.element - - # - def is_planar_complex(self): - return False - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - MathOperation.and_popc: 'and' - } - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - inst_shape += math_op_string - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm' - - return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - if self.is_complex(): - extended_name = "${core_name}" - else: - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] - ) - return "%s" % (ShortLayoutTypeNames[self.A.layout]) - - # - def side_mode_name(self): - return "%s" % (ShortSideModeNames[self.A.side_mode]) - - # - def fill_mode_name(self): - return "%s" % (ShortFillModeNames[self.A.fill_mode]) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = self.tile_description.procedural_name() - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = self.C.alignment - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'side_mode': self.side_mode_name(), - 'fill_mode': self.fill_mode_name(), - 'alignment': "%d" % alignment, - } - ) - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -# -class EmitSymmUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self): - self.symm_template = """ -// Symm operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::Symm< - ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - ${split_k_serial}, - ${math_operation} ->; -""" - self.symm_complex_template = """ -// Symm operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::Symm< - ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - ${split_k_serial}, - ${math_operation}, - ${blas_mode} ->; -""" - - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - - warp_count = operation.tile_description.warp_count - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'side_mode': SideModeTag[operation.A.side_mode], - 'fill_mode': FillModeTag[operation.A.fill_mode], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'split_k_serial': 'false', - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'blas_mode': BlasModeTag[operation.blas_mode] - } - - symm_template = self.symm_complex_template if operation.is_complex() else self.symm_template - - return SubstituteTemplate(symm_template, values) - -################################################################################################### - - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitSymmConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') - - self.instance_emitter = { - SymmKind.Universal: EmitSymmUniversalInstance, - } - - self.symm_kind_wrappers = { - SymmKind.Universal: 'SymmOperation', - } - - self.instance_template = { - SymmKind.Universal: """ -${compile_guard_start} - manifest.append(new ${symm_kind}< - Operation_${operation_name} - >("${operation_name}")); -${compile_guard_end} -""" - } - - self.header_template = """ -/* - Generated by symm_operation.py - Do not edit. -*/ - -/////////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "library_internal.h" -#include "symm_operation.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.initialize_function_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_${configuration_name}(Manifest &manifest) { - -""" - self.epilogue_template = """ - -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def __enter__(self): - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(self.header_template) - - self.instance_definitions = [] - self.instance_wrappers = [] - - self.operations = [] - return self - - def emit(self, operation): - emitter = self.instance_emitter[operation.symm_kind]() - - self.operations.append(operation) - - self.instance_definitions.append(emitter.emit(operation)) - - self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.symm_kind], { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'symm_kind': self.symm_kind_wrappers[operation.symm_kind], - 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", - 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" - })) - - def __exit__(self, exception_type, exception_value, traceback): - - # Write instance definitions in top-level namespace - for instance_definition in self.instance_definitions: - self.configuration_file.write(instance_definition) - - # Add wrapper objects within initialize() function - self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { - 'configuration_name': self.configuration_name - })) - - for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) - - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py deleted file mode 100644 index 46ba360cb615c955d329b390c0ab93d13ed88c7c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py +++ /dev/null @@ -1,447 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for emitting Trmm kernels -""" - -import enum -import functools -import operator -import os.path -import shutil - -try: - import builtins - if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: - raise ImportError("Disabling attempt to import cutlass_library") - from cutlass_library.library import * -except ImportError: - from library import * - - -################################################################################################### -# -# Data structure modeling a TRMM operation -# -################################################################################################### - -# -class TrmmOperation: - # - def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): - - self.operation_kind = OperationKind.Trmm - self.arch = arch - self.tile_description = tile_description - self.trmm_kind = trmm_kind - self.A = A - self.B = B - self.C = C - self.element_epilogue = element_epilogue - self.epilogue_functor = epilogue_functor - self.swizzling_functor = swizzling_functor - - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 - ] - return self.tile_description.math_instruction.math_operation in complex_operators - return False - - # - def is_planar_complex(self): -# return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray) - return False - - # - def is_mixed_input(self): - return self.A.element != self.B.element - - # - def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator - - if self.is_complex(): - return get_complex_from_real(accum) - - return accum - - # - def short_math_name(self): - if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: - return "g%s" % ShortDataTypeNames[self.accumulator_type()] - return ShortDataTypeNames[self.accumulator_type()] - - - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - - inst_shape = '' - inst_operation = '' - intermediate_type = '' - - math_operations_map = { - MathOperation.xor_popc: 'xor', - MathOperation.and_popc: 'and' - } - - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: - - math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - inst_shape += math_op_string - - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: - intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] - - return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind]) - - # - def extended_name(self): - ''' Append data types if they differ from compute type. ''' - if self.is_complex(): - extended_name = "${core_name}" - else: - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: - extended_name = "${core_name}" - - extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() - }) - - return extended_name - - # - def layout_name(self): - if self.is_complex() or self.is_planar_complex(): - return "%s%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] - ) - return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) - - # - def side_mode_name(self): - return "%s" % (ShortSideModeNames[self.A.side_mode]) - - # - def fill_mode_name(self): - return "%s" % (ShortFillModeNames[self.A.fill_mode]) - - # - def diag_type_name(self): - return "%s" % (ShortDiagTypeNames[self.A.diag_type]) - - # - def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - threadblock = self.tile_description.procedural_name() - - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - - alignment = max([self.C.alignment]) - - return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}", - { - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'side_mode': self.side_mode_name(), - 'fill_mode': self.fill_mode_name(), - 'diag_type': self.diag_type_name(), - 'alignment': "%d" % self.C.alignment, - } - ) - - # - def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' - return self.procedural_name() - -################################################################################################### -# -# Emits single instances of a CUTLASS device-wide operator -# -################################################################################################### - -# -class EmitTrmmUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self): - self.trmm_template = """ -// Trmm operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::Trmm< - ${element_a}, ${layout_a}, - ${side_mode}, ${fill_mode}, ${diag_type}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue}, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - ${split_k_serial}, - ${math_operation} ->; -""" - self.trmm_complex_template = """ -// Trmm operator ${operation_name} -using Operation_${operation_name} = - typename cutlass::gemm::device::Trmm< - ${element_a}, ${layout_a}, - ${side_mode}, ${fill_mode}, ${diag_type}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue}, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - ${split_k_serial}, - ${math_operation}, - ${transform_a} ->; -""" - - def emit(self, operation): - - threadblock_shape = operation.tile_description.threadblock_shape - warp_count = operation.tile_description.warp_count - - warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'side_mode' : SideModeTag[operation.A.side_mode], - 'fill_mode': FillModeTag[operation.A.fill_mode], - 'diag_type' : DiagTypeTag[operation.A.diag_type], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], - 'stages': str(operation.tile_description.stages), - 'align_a': str(1), # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes - 'align_b': str(operation.B.alignment), - 'split_k_serial': 'false', - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'transform_a': ComplexTransformTag[operation.A.complex_transform] - } - - trmm_template = self.trmm_complex_template if operation.is_complex() else self.trmm_template - - return SubstituteTemplate(trmm_template, values) - -################################################################################################### - - -################################################################################################### -# -# Emitters functions for all targets -# -################################################################################################### - -class EmitTrmmConfigurationLibrary: - def __init__(self, operation_path, configuration_name): - self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') - - self.instance_emitter = { - TrmmKind.Universal: EmitTrmmUniversalInstance, - } - - self.trmm_kind_wrappers = { - TrmmKind.Universal: 'TrmmOperation', - } - - self.instance_template = { - TrmmKind.Universal: """ -${compile_guard_start} - manifest.append(new ${trmm_kind}< - Operation_${operation_name} - >("${operation_name}")); -${compile_guard_end} -""" - } - - self.header_template = """ -/* - Generated by trmm_operation.py - Do not edit. -*/ - -/////////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "library_internal.h" -#include "trmm_operation.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - self.initialize_function_template = """ - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_${configuration_name}(Manifest &manifest) { - -""" - self.epilogue_template = """ - -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -""" - - def __enter__(self): - self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(self.header_template) - - self.instance_definitions = [] - self.instance_wrappers = [] - - self.operations = [] - return self - - def emit(self, operation): - emitter = self.instance_emitter[operation.trmm_kind]() - - self.operations.append(operation) - - self.instance_definitions.append(emitter.emit(operation)) - - self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name(), - 'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind], - 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", - 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" - })) - - def __exit__(self, exception_type, exception_value, traceback): - - # Write instance definitions in top-level namespace - for instance_definition in self.instance_definitions: - self.configuration_file.write(instance_definition) - - # Add wrapper objects within initialize() function - self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { - 'configuration_name': self.configuration_name - })) - - for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) - - self.configuration_file.write(self.epilogue_template) - self.configuration_file.close() - -################################################################################################### diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py deleted file mode 100644 index c396d75a5534493f1ebf90043f2a182eb46abb7f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py +++ /dev/null @@ -1,132 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -# Configuration file for the Sphinx documentation builder. -# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys - -sys.path.insert(0, os.path.abspath('..')) -sys.path.insert(0, os.path.abspath('../..')) -sys.path.insert(0, os.path.abspath('../../media/docs')) - -# -- Project information ----------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information - -project = 'CUTLASS Python interface' -copyright = '2023, NVIDIA' -author = 'NVIDIA' -release = '3.1.0' - -# -- General configuration --------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration - - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'myst_parser', - 'nbsphinx', - 'nbsphinx_link', - 'sphinx_copybutton', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosectionlabel', - 'sphinx.ext.autosummary', - 'sphinx.ext.coverage', - 'sphinx.ext.extlinks', - 'sphinx.ext.ifconfig', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx_inline_tabs', - ] - -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -autodoc_typehints = 'description' - -pygments_style = "sphinx" -pygments_dark_style = "monokai" - -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# Ignore errors when converting notebooks -nbsphinx_allow_errors = True - -language = 'en' -# -- Options for HTML output ------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output - -html_static_path = ['_static'] - -html_title = "CUTLASS Python" -html_baseurl = 'docs' -html_theme = 'furo' -html_theme_options = { - "light_logo": "cutlass-logo-small.png", - "dark_logo": "cutlass-logo-small.png", - "light_css_variables": { - "color-brand-primary": "#76B900", - "color-brand-content": "#76B900", - }, - "dark_css_variables": { - "color-brand-primary": "#76B900", - "color-brand-content": "#76B900", - }, - "footer_icons": [ - { - "name": "GitHub", - "url": "https://github.com/NVIDIA/cutlass", - "html": """ - - - - """, - "class": "", - }, - ], -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py deleted file mode 100644 index 308a5676b06f00089d1cdfe0fb83b442ca2df36e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from .int_tuple import * -from .layout import * -from .swizzle import * -from .typing import * diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py deleted file mode 100644 index 3d722130c52142e68a3bcd54ac708012aeeeaad3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py +++ /dev/null @@ -1,225 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Functions for manipulating IntTuples -""" - -from functools import reduce -from itertools import chain -from typing import Union -from .typing import Integer - - -def is_int(x): - return isinstance(x, Integer) - - -def is_tuple(x): - return isinstance(x, tuple) - - -def flatten(t): - if is_tuple(t): - if len(t) == 0: - return () - else: - return tuple(i for a in t for i in flatten(a)) - else: - return (t,) - - -def signum(a): - return bool(a > 0) - bool(a < 0) - - -def product(a): - if is_tuple(a): - return reduce(lambda val,elem : val*product(elem), a, 1) - else: - return a - - -def inner_product(a, b): - if is_tuple(a): # tuple tuple - assert len(a) == len(b) - return sum(inner_product(x,y) for x,y in zip(a,b)) - else: # "int" "int" - assert not is_tuple(b) - return a * b - - -def tuple_max(a): - if is_tuple(a): - return max(tuple_max(x) for x in a) - else: - return a - - -def elem_scale(a, b): - if is_tuple(a): - if is_tuple(b): # tuple tuple - assert len(a) == len(b) - return tuple(elem_scale(x,y) for x,y in zip(a,b)) - else: # tuple "int" - assert False # Error - else: - if is_tuple(b): # "int" tuple - return elem_scale(a, product(b)) - else: # "int" "int" - return a * b - - -# Inclusive prefix ceil div with output congruent to input a -def shape_div(a, b): - if is_tuple(a): - if is_tuple(b): # tuple tuple - assert len(a) == len(b) - return tuple(shape_div(x,y) for x,y in zip(a,b)) - else: # tuple "int" - #r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] - r = [] - for v in a: - r.append(shape_div(v,b)) - b = shape_div(b,product(v)) - return tuple(r) - else: - if is_tuple(b): # "int" tuple - return shape_div(a, product(b)) - else: # "int" "int" - assert a % b == 0 or b % a == 0 - return (a + b - 1) // b - -# Exclusive prefix product with output congruent to input a -def prefix_product(a, init=1): - if is_tuple(a): - if is_tuple(init): # tuple tuple - assert len(a) == len(init) - return tuple(prefix_product(x,i) for x,i in zip(a,init)) - else: # tuple "int" - #r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))] - r = [] - for v in a: - r.append(prefix_product(v,init)) - init = init * product(v) - return tuple(r) - else: - if is_tuple(init): # "int" tuple - assert False # Error - else: # "int" "int" - return init - - -def idx2crd(idx, shape, stride=None): - if stride is None: - stride = prefix_product(shape) - - if is_tuple(idx): - if is_tuple(shape): # tuple tuple tuple - assert len(idx) == len(shape) and len(idx) == len(stride) - return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride)) - else: # tuple "int" "int" - assert False # Error - else: - if is_tuple(shape): # "int" tuple tuple - assert len(shape) == len(stride) - return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride)) - else: # "int" "int" "int" - return (idx // stride) % shape - - -def crd2idx(crd, shape, stride=None): - if stride is None: - stride = prefix_product(shape) - - if is_tuple(crd): - if is_tuple(shape): # tuple tuple tuple - assert len(crd) == len(shape) and len(crd) == len(stride) - return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) - else: # tuple "int" "int" - assert False, f"crd={crd}, shape={shape}" # Error - else: - if crd is None: - crd = 0 - - if is_tuple(shape): # "int" tuple tuple - assert len(shape) == len(stride) - result = 0 - for i in range(len(shape)-1): - result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) - crd = crd // product(shape[i]) - return result + crd2idx(crd, shape[-1], stride[-1]) - else: # "int" "int" "int" - return crd * stride - - -# Transform crd into the dst_shape's iteration space -def crd2crd(crd, dst_shape, src_shape=None): - if is_tuple(crd): - if is_tuple(dst_shape): # tuple tuple - assert len(crd) == len(dst_shape) - return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape)) - else: # tuple "int" - # Ambiguous unless we have src_shape - assert src_shape is not None - return crd2idx(crd, src_shape) - else: - if is_tuple(dst_shape): # "int" tuple - return idx2crd(crd, dst_shape) - else: # "int" "int" - assert crd < dst_shape - return crd - - -# Filter trg according to crd: keep only elements of trg that are paired with None -def slice_(crd: Union[None, tuple, int], - trg: Union[tuple, int]): - if is_tuple(crd): - if is_tuple(trg): # tuple tuple - assert len(crd) == len(trg) - # match C++ behavior of `filter_tuple` using `tuple_cat(...)` - return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)]))) - else: - assert False # tuple "int" : Error - elif crd is None: - # match C++ behavior `return cute::tuple{b};` - return (trg,) - else: - return () - - -# Determine if None appears at any of an int_tuples' terminals -def has_none(a: Union[None, tuple, int]): - if is_tuple(a): - return any(has_none(v) for v in a) - else: - return a is None diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py deleted file mode 100644 index 7c220eb16dd089c65fdbe6d6929b357ace0a77c1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py +++ /dev/null @@ -1,367 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Definition of CuTe Layouts and functions to manipulate them -""" - -from itertools import chain -from typing import Union - -from .int_tuple import * - - -class LayoutBase: - pass - - -def is_layout(x): - return isinstance(x, LayoutBase) - - -class Layout(LayoutBase): - def __init__(self, _shape, _stride=None): - self.shape = _shape - if _stride is None: - self.stride = prefix_product(self.shape) - else: - self.stride = _stride - - # operator == - def __eq__(self, other): - return self.shape == other.shape and self.stride == other.stride - - # operator len(L) (len [rank] like tuples) - def __len__(self): - if is_tuple(self.shape): - return len(self.shape) - else: - return 1 - - # operator () (map coord to idx) - def __call__(self, *args): - """ - Map a logical coordinate to a linear index (Coord has no Underscore slice operators) - OR - Slice the layout and return the sublayout (Coord has an Underscore slice op) - - Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ - """ - if has_none(args): - if len(args) == 1: - return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) - else: - return Layout(slice_(args, self.shape), slice_(args, self.stride)) - else: - if len(args) == 1: - return crd2idx(args[0], self.shape, self.stride) - else: - return crd2idx(args, self.shape, self.stride) - - # operator [] (get-i like tuples) - def __getitem__(self, i): - if is_tuple(self.shape): - return Layout(self.shape[i], self.stride[i]) - else: - assert i == 0 - return Layout(self.shape, self.stride) - - # size(layout) Size of the domain - def size(self): - return product(self.shape) - - # cosize(layout) Size of the codomain - def cosize(self): - return self(self.size() - 1) + 1 - - # print and str - def __str__(self): - return f"{self.shape}:{self.stride}" - - # error msgs and representation - def __repr__(self): - return f"Layout({self.shape},{self.stride})" - - -# Make Layout from a list of layouts (each layout it's own mode in the result) -def make_layout(*layouts): - if len(layouts) == 1 and not is_layout(layouts[0]): - layouts = layouts[0] - - shape, stride = zip(*((a.shape,a.stride) for a in layouts)) - return Layout(shape, stride) - - -# Size of the domain -def size(layout): - if is_layout(layout): - return layout.size() - return product(layout) - - -# Size of the codomain -def cosize(layout): - return layout.cosize() - - -# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function -def coalesce(layout, profile=None): - if is_tuple(profile): - assert len(layout) >= len(profile) - return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))), - (layout[i] for i in range(len(profile),len(layout))))) - - result_shape = [1] - result_stride = [0] - for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): - # skip their shape-1s - if shape == 1: - continue - # replace our shape-1 with anything - elif result_shape[-1] == 1: - result_shape[-1] = shape - result_stride[-1] = stride - # merge modes if the shape*stride match - elif result_shape[-1] * result_stride[-1] == stride: - result_shape[-1] = result_shape[-1] * shape - # append a new mode - else: - result_shape.append(shape) - result_stride.append(stride) - - if len(result_shape) == 1: - return Layout(result_shape[0], result_stride[0]) - else: - return Layout(tuple(result_shape), tuple(result_stride)) - - -# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them -def filter(layout, profile=None): - if is_tuple(profile): - assert len(layout) >= len(profile) - return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))), - (layout[i] for i in range(len(profile),len(layout))))) - - result_shape = [] - result_stride = [] - for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): - # skip their shape-1s and stride-0s - if not (shape == 1 or stride == 0): - result_shape.append(shape) - result_stride.append(stride) - - if len(result_shape) == 0: - return Layout(1,0) - else: - return coalesce(Layout(tuple(result_shape), tuple(result_stride))) - - -# Layout composition -# Use tuples-of-layouts to perform this operation by-mode and None as no-op -def composition(layoutA, layoutB): - if layoutB is None: - return layoutA - elif is_int(layoutB): - return composition(layoutA, Layout(layoutB)) - elif is_tuple(layoutB): - assert len(layoutA) >= len(layoutB) - return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), - (layoutA[i] for i in range(len(layoutB),len(layoutA))))) - elif is_tuple(layoutB.shape): - return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) - - if layoutB.stride == 0: - return Layout(layoutB.shape, 0) - else: - result_shape = [] - result_stride = [] - rest_shape = layoutB.shape - rest_stride = layoutB.stride - flat_A = coalesce(layoutA) - for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]): - assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 - new_shape = min(max(1, curr_shape // rest_stride), rest_shape) - - if new_shape != 1: - result_shape.append(new_shape) - result_stride.append(rest_stride * curr_stride) - - rest_shape = rest_shape // new_shape - rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride) - - if rest_shape != 1 or len(result_shape) == 0: - result_shape.append(rest_shape) - result_stride.append(rest_stride * flatten(flat_A.stride)[-1]) - - if len(result_shape) == 1: - return Layout(result_shape[0], result_stride[0]) - else: - return Layout(tuple(result_shape), tuple(result_stride)) - - -# Layout complement -def complement(layout, max_idx=1): - if is_int(layout): - return complement(Layout(layout)) - - result_shape = [] - result_stride = [] - current_idx = 1 - - sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) - for (stride, shape) in sorted_DS: - if stride == 0 or shape == 1: - continue - - in_bound = current_idx <= shape * stride - # To support symbolic value which can't be evaluated now - assert (type(in_bound) is not bool) or in_bound - - result_shape.append(stride // current_idx) - result_stride.append(current_idx) - current_idx = shape * stride - - result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div - result_stride.append(current_idx) - - return coalesce(Layout(tuple(result_shape), tuple(result_stride))) - - -# Layout right inverse -def right_inverse(layout): - if layout is None: - return None - elif is_int(layout): - return Layout(layout) - - result_shape = [] - result_stride = [] - current_idx = 1 - - flat_shape = flatten(layout.shape) - flat_stride = flatten(layout.stride) - sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) - for (stride,shape,rstride) in sorted_DSA: - if shape == 1: - continue - if current_idx != stride: - break - - result_shape.append(shape) - result_stride.append(rstride) - current_idx = shape * stride - - return coalesce(Layout(tuple(result_shape), tuple(result_stride))) - - -# Layout left inverse -def left_inverse(layout): - if layout is None: - return None - elif is_int(layout): - return Layout(layout) - return right_inverse(make_layout(layout, complement(layout))) - - -# Split a layout by the composition of B and the "rest" -# Use tuples-of-layouts to perform this operation by-mode and None as no-op -def logical_divide(layoutA, layoutB): - if layoutB is None: - return layoutA - elif is_int(layoutB): - return logical_divide(layoutA, Layout(layoutB)) - elif is_tuple(layoutB): - assert len(layoutA) >= len(layoutB) - return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), - (layoutA[i] for i in range(len(layoutB),len(layoutA))))) - - return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))) - - -# Reproduce a layoutA over a layoutB -# Use tuples-of-layouts to perform this operation by-mode and None as no-op -def logical_product(layoutA, layoutB): - if layoutB is None: - return layoutA - elif is_int(layoutB): - return logical_divide(layoutA, Layout(layoutB)) - elif is_tuple(layoutB): - assert len(layoutA) >= len(layoutB) - return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), - (layoutA[i] for i in range(len(layoutB),len(layoutA))))) - - return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB)); - - -# Gather the modes from a hierarchical logical_divide or logical_product -def hier_unzip(splitter, layoutA, layoutB): - if layoutB is None: - return make_layout(Layout(1,0), layoutA) - elif is_tuple(layoutB): - assert len(layoutA) >= len(layoutB) - # A layout with shape ((A,a),(B,b),(C,c)) - split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB))) - # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) - return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))), - make_layout(chain((split[i][1] for i in range( 0,len(layoutB))), - (layoutA[i] for i in range(len(layoutB),len(layoutA)))))) - - # splitter must return a rank-2 layout - return splitter(layoutA, layoutB) - - -# Apply logical divide hierarchically and gather the split modes into two modes -def zipped_divide(layoutA, layoutB): - return hier_unzip(logical_divide, layoutA, layoutB) - - -# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode -def tiled_divide(layoutA, layoutB): - result = zipped_divide(layoutA, layoutB) - return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) - - -# Apply logical product hierarchically and gather the split modes into two modes -def zipped_product(layoutA, layoutB): - return hier_unzip(logical_product, layoutA, layoutB) - - -# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode -def tiled_product(layoutA, layoutB): - result = zipped_product(layoutA, layoutB) - return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) - - -def slice_and_offset(crd: tuple, - layout: Layout): - return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), - crd2idx(crd, layout.shape, layout.stride)) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py deleted file mode 100644 index 308aee0c3838a82c4de53833fb8a36950b30f62d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py +++ /dev/null @@ -1,129 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Methods for layout swizzling -""" - -from .layout import * - - -def shiftr(a, s): - return a >> s if s > 0 else shiftl(a, -s) - - -def shiftl(a, s): - return a << s if s > 0 else shiftr(a, -s) - - -## A generic Swizzle functor - # 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx - # ^--^ Base is the number of least-sig bits to keep constant - # ^-^ ^-^ Bits is the number of bits in the mask - # ^---------^ Shift is the distance to shift the YYY mask - # (pos shifts YYY to the right, neg shifts YYY to the left) - # - # e.g. Given - # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx - # the result is - # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY - # -class Swizzle: - def __init__(self, bits, base, shift): - assert bits >= 0 - assert base >= 0 - assert abs(shift) >= bits - self.bits = bits - self.base = base - self.shift = shift - bit_msk = (1 << bits) - 1 - self.yyy_msk = bit_msk << (base + max(0,shift)) - self.zzz_msk = bit_msk << (base - min(0,shift)) - - # operator () (transform integer) - def __call__(self, offset): - return offset ^ shiftr(offset & self.yyy_msk, self.shift) - - # Size of the domain - def size(self): - return 1 << (self.bits + self.base + abs(self.shift)) - - # Size of the codomain - def cosize(self): - return self.size() - - # print and str - def __str__(self): - return f"SW_{self.bits}_{self.base}_{self.shift}" - - # error msgs and representation - def __repr__(self): - return f"Swizzle({self.bits},{self.base},{self.shift})" - - -class ComposedLayout(LayoutBase): - def __init__(self, layoutB, offset, layoutA): - self.layoutB = layoutB - self.offset = offset - self.layoutA = layoutA - - # operator == - def __eq__(self, other): - return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA - - # operator len(L) (len [rank] like tuples) - def __len__(self): - return len(self.layoutA) - - # operator () (map coord to idx) - def __call__(self, *args): - return self.layoutB(self.offset + self.layoutA(*args)) - - # operator [] (get-i like tuples) - def __getitem__(self, i): - return ComposedLayout(self.layoutB, self.offset, self.layoutA[i]) - - # size(layout) Size of the domain - def size(self): - return size(self.layoutA) - - # cosize(layout) Size of the codomain - def cosize(self): - return cosize(self.layoutB) - - # print and str - def __str__(self): - return f"{self.layoutB} o {self.offset} o {self.layoutA}" - - # error msgs and representation - def __repr__(self): - return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py deleted file mode 100644 index 834f7e5411f5c2a4e218f9ce8a4f0a229d039710..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py +++ /dev/null @@ -1,42 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from abc import ABC - - -class Integer(ABC): - @classmethod - def __subclasshook__(cls, c): - if c in [bool, float]: - return False - - return issubclass(c, int) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py deleted file mode 100644 index acc0c46e540735443a4943908852010a80d02187..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py +++ /dev/null @@ -1,74 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - - -import copy -import os -import setuptools -from setuptools import setup -from setuptools.command.build_ext import build_ext - -import setup_pycute -import setup_library - - -# Install cutlass_library package -setup_library.perform_setup() - - -# Install the PyCuTe package -setup_pycute.perform_setup() - - -setup( - name='cutlass_cppgen', - version='4.2.0', - description='CUTLASS Pythonic Interface', - package_dir={'': '.'}, - packages=[ - 'cutlass_cppgen', - 'cutlass_cppgen.emit', - 'cutlass_cppgen.op', - 'cutlass_cppgen.utils', - 'cutlass_cppgen.backend', - 'cutlass_cppgen.backend.utils' - ], - setup_requires=['pybind11'], - install_requires=[ - 'bfloat16', - 'cuda-python>=11.8.0', - 'pybind11', - 'scikit-build', - 'treelib', - 'pydot' - ] -) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_library.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_library.py deleted file mode 100644 index c56d6b5556fea2d5e56209b13f5b95e487ca22fb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_library.py +++ /dev/null @@ -1,46 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from setuptools import setup - - -def perform_setup(): - setup( - name='cutlass_library', - version='4.2.1', - description='CUTLASS library generation scripts', - packages=['cutlass_library'] - ) - - -if __name__ == '__main__': - perform_setup() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py deleted file mode 100644 index 0bad050fcade8b26d33043abbb0f8226be7d816c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py +++ /dev/null @@ -1,46 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from setuptools import setup - - -def perform_setup(): - setup( - name='pycute', - version='4.2.1', - description='Python implementation of CuTe', - packages=['pycute'], - ) - - -if __name__ == '__main__': - perform_setup() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py deleted file mode 100644 index 852c0277ebae2fce7e0b083ce2f497a2c828256f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py +++ /dev/null @@ -1,661 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for defining Conv2D problem sizes for testing. - -This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h -""" - -from cutlass_library import ConvMode - -import cutlass_cppgen -from cutlass_cppgen.shape import Conv2DProblemSize - - -class TestbedConv2dProblemSizes: - def __init__(self, minimum_channel_size: int): - conv2d_default_sizes = self.initialize_conv2d_default_sizes(minimum_channel_size) - conv2d_rigorous_sizes = self.initialize_conv2d_rigorous_sizes(minimum_channel_size) - conv2d_resnet50_sizes = self.initialize_conv2d_resnet50_sizes(1) - conv2d_resnet50_sizes_perf = self.initialize_conv2d_resnet50_sizes(34) - grouped_sizes = self.initialize_conv2d_grouped_sizes() - - # Filter all problems - self.all = [] - for size_list in [conv2d_default_sizes, conv2d_rigorous_sizes, conv2d_resnet50_sizes, conv2d_resnet50_sizes_perf, grouped_sizes]: - for size in size_list: - if (size.C // size.groups) % minimum_channel_size == 0: - self.all.append(size) - - - def initialize_conv2d_default_sizes(self, minimum_channel_size): - # Small input size x stride (1,1) - # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - - conv2d_default_sizes = [] - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 1, 1, minimum_channel_size, - 8, 1, 1, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 1, 8, minimum_channel_size, - 8, 1, 3, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 7, 8, minimum_channel_size, - 8, 3, 3, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 7, 9, minimum_channel_size, - 8, 4, 4, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 2, 7, 9, minimum_channel_size, - 8, 5, 5, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 3, 7, 9, minimum_channel_size, - 8, 6, 5, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 3, 7, 9, minimum_channel_size, - 8, 6, 6, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 3, 7, 9, minimum_channel_size, - 8, 7, 7, minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - ############################################## - # Small input size x stride (2,2) - # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - ############################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 11, 7, minimum_channel_size, - 8, 1, 1, minimum_channel_size, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 11, 7, minimum_channel_size, - 8, 3, 3, minimum_channel_size, - 1, 1, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 13, 11, minimum_channel_size, - 8, 1, 1, minimum_channel_size, - 1, 1, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 17, 19, minimum_channel_size, - 16, 2, 2, minimum_channel_size, - 1, 1, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 23, 5, minimum_channel_size, - 16, 3, 3, minimum_channel_size, - 1, 1, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 13, 17, 8, - 24, 3, 3, 8, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 23, 21, 8, - 24, 3, 3, 8, - 1, 1, - 3, 3, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 20, 24, 8, - 40, 3, 3, 8, - 3, 3, - 3, 3, - 1, 1, - )) - - ########################################## - # Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) - ########################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 15, 19, 160, - 224, 1, 1, 160, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 19, 37, 160, - 224, 3, 3, 160, - 1, 1, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 16, 16, 160, - 224, 2, 3, 160, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 23, 21, 128, - 224, 3, 3, 128, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 29, 37, 160, - 224, 5, 5, 160, - 2, 2, - 1, 1, - 1, 1, - )) - - ########################################## - # C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - ########################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 15, 19, 32 + minimum_channel_size, - 96, 3, 3, 32 + minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 16, 24, 64 + minimum_channel_size, - 96, 3, 3, 64 + minimum_channel_size, - 1, 1, - 1, 1, - 1, 1, - )) - - ########################################## - # Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) - ########################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 13, 16, 288, - 160, 5, 5, 288, - 2, 2, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 55, 51, 256, - 512, 1, 1, 256, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 71, 80, 32, - 64, 5, 5, 32, - 2, 2, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 224, 224, 8, - 64, 7, 7, 8, - 3, 3, - 2, 2, - 1, 1, - )) - - ########################################## - # Medium input size stride (3, 3), filter (3, 3), non-default padding - ########################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 27, 23, 256, - 512, 3, 3, 256, - 0, 0, - 3, 3, - 1, 1, - )) - - ########################################## - # Medium input size padding > stride, asymmetric filter, padding and striding - ########################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 27, 31, 256, - 512, 3, 3, 256, - 5, 7, - 3, 4, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 27, 35, 256, - 512, 7, 5, 256, - 11, 7, - 3, 5, - 1, 1, - )) - - ########################################## - # Medium input size *mixed* stride (1, 2) and (2, 1), - # filter (3, 3), default padding - ########################################## - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 27, 27, 256, - 512, 3, 3, 256, - 1, 1, - 1, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 27, 27, 256, - 512, 3, 3, 256, - 1, 1, - 2, 1, - 1, 1, - )) - - ######################################/ - # Additional input size - ######################################/ - conv2d_default_sizes.append(Conv2DProblemSize( - 3, 28, 28, 256, - 256, 2, 2, 256, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 1, 32, 32, 16, - 32, 3, 3, 16, - 1, 1, - 6, 2, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 32, 24, 32, 32, - 32, 1, 2, 32, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_default_sizes.append(Conv2DProblemSize( - 4, 2, 3, 256, - 328, 3, 5, 256, - 1, 1, - 1, 1, - 1, 1, - )) - return conv2d_default_sizes - - # Add a few large and rigorous convolution problem sizes - def initialize_conv2d_rigorous_sizes(self, minimum_channel_size): - sizes = [] - if False: - sizes.append(Conv2DProblemSize.from_sizes( - (1, 124, 224, 2 * minimum_channel_size), - (24, 7, 7, 2 * minimum_channel_size), - )) - - sizes.append(Conv2DProblemSize.from_sizes( - (1, 233, 35, minimum_channel_size), - (24, 7, 5, minimum_channel_size), - )) - return sizes - - # Add resent50 layers to unit testing sizes - def initialize_conv2d_resnet50_sizes(self, batch_size): - conv2d_problem_vector = [] - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 56, 56, 64, - 256, 1, 1, 64, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 56, 56, 64, - 64, 1, 1, 64, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 56, 56, 64, - 64, 3, 3, 64, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 56, 56, 256, - 64, 1, 1, 256, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 56, 56, 256, - 512, 1, 1, 256, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 56, 56, 256, - 128, 1, 1, 256, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 28, 28, 128, - 128, 3, 3, 128, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 28, 28, 128, - 512, 1, 1, 128, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 28, 28, 512, - 128, 1, 1, 512, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 28, 28, 512, - 1024, 1, 1, 512, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 28, 28, 512, - 256, 1, 1, 512, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 14, 14, 256, - 256, 3, 3, 256, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 14, 14, 256, - 1024, 1, 1, 256, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 14, 14, 1024, - 256, 1, 1, 1024, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 14, 14, 1024, - 2048, 1, 1, 1024, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 14, 14, 1024, - 512, 1, 1, 1024, - 0, 0, - 2, 2, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 7, 7, 512, - 512, 3, 3, 512, - 1, 1, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 7, 7, 512, - 2048, 1, 1, 512, - 0, 0, - 1, 1, - 1, 1, - )) - - conv2d_problem_vector.append(Conv2DProblemSize( - batch_size, 7, 7, 2048, - 512, 1, 1, 2048, - 0, 0, - 1, 1, - 1, 1, - )) - - return conv2d_problem_vector - - def initialize_conv2d_grouped_sizes(self): - threadblock_n = 128 - threadblock_k = 32 - - sizes = [] - ########################################## - # One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 - # One CTA calculates a single group - ########################################## - for cta_per_group_k in range(1, 4): - for groups in range(2, 5): - conv_k = cta_per_group_k * threadblock_n * groups - sizes.append(Conv2DProblemSize( - 1, 8, 8, threadblock_k * 2 * groups, - conv_k, 3, 3, threadblock_k * 2, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - groups - )) - - # Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K - sizes.append(Conv2DProblemSize( - 1, 8, 8, threadblock_k, - threadblock_n * 2, 3, 3, threadblock_k // 2, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 2 - )) - - sizes.append(Conv2DProblemSize( - 1, 56, 56, 696, - 768, 3, 3, 232, - 1, 1, - 2, 2, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 3 - )) - sizes.append(Conv2DProblemSize( - 1, 14, 14, 1392, - 1536, 3, 3, 232, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 3 - )) - - ########################################## - # One CTA calculate multiple groups: CTA::N % k_per_group = 0 - ########################################## - - # 2 groups per CTA - sizes.append(Conv2DProblemSize( - 1, 8, 8, threadblock_k * 4, - threadblock_n, 3, 3, threadblock_k * 2, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 2 - )) - - # 2 groups per CTA and partial gemm_k - sizes.append(Conv2DProblemSize( - 1, 8, 8, threadblock_k, - threadblock_n, 3, 3, threadblock_k // 2, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 2 - )) - - # 4 groups per CTA - sizes.append(Conv2DProblemSize( - 1, 8, 8, threadblock_k * 8, - threadblock_n // 2, 3, 3, threadblock_k * 2, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 4 - )) - - # 4 groups per CTA and partial gemm_k - sizes.append(Conv2DProblemSize( - 1, 8, 8, threadblock_k * 2, - threadblock_n // 2, 3, 3, threadblock_k // 2, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, - 4 - )) - - return sizes diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py deleted file mode 100644 index f77a0ec831be087bd3badc929eee955f0b37c489..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py +++ /dev/null @@ -1,146 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for Conv2d opreations on SM80 -""" - -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from conv2d_test_utils import * - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 80 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is invalid for SM80 tests.') -class Conv2dSm80(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -conv_problems = get_conv_problems() - - -# Tests for optimized & analytic -for conv_kind in ["fprop", "wgrad", "dgrad"]: - # F16, simt - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="simt", threadblock_shape=[128, 128, 8], - warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1]) - # F16, tensor op - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) - # F16, tensor op, analytic iterator - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic") - # F16, tensor op, f32 output - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) - # F16, tensor op, different tile description - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 64, 32], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]) - # F32, simt - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, - opclass="simt", threadblock_shape=[128, 128, 8], - warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1]) - # Tf32, tensorop - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, - opclass="tensor_op", threadblock_shape=[128, 128, 16], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8] - ) - # Split-K - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial", - split_k_slices=2) - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel", - split_k_slices=5) - # Swizzling functor - add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 64, 32], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4) - -# Tests for few channels and fixed channels -# F16, tensor op, few channels -for c, tb, stage, inst in zip([2, 1], - [[128, 128, 64], [128, 128, 32]], - [3, 2], - [[16, 8, 16], [16, 8, 8]]): - add_test( - Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=tb, - warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels" - ) -# F16, tensor op, fixed channels -for c in [8, 4, 2]: - add_test( - Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels" - ) - -# Test activations -for activation in ["relu", "leaky_relu"]: - for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]): - add_test( - Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, - opclass="tensor_op", threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode, - split_k_slices=split_k_slices, activation=activation) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py deleted file mode 100644 index 9bc4542cd5ccf72341f7db3c7947d481b032926d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py +++ /dev/null @@ -1,428 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utility functions for Conv2d tests. -""" - -from cutlass_library import SubstituteTemplate -import torch - -import cutlass_cppgen -from cutlass_library import ( - ConvKind, - ConvMode, - DataType, - DataTypeNames, - EpilogueScheduleSuffixes, - KernelScheduleSuffixes, - LayoutType, - OpcodeClassNames, - ShortDataTypeNames, - ShortLayoutTypeNames, - SplitKMode, -) -from cutlass_cppgen.shape import Conv2DProblemSize -from cutlass_cppgen.utils.datatypes import numpy_type, torch_type - -from conv2d_problem_sizes import TestbedConv2dProblemSizes - - -def get_name_conv2d( - arch, - conv_kind, - element, - element_accumulator, - element_output, - opclass, - threadblock_shape, - warp_count, - instruction_shape, - stages, - iterator_algorithm, - swizzle, - split_k_mode, - split_k_slices, - activation -): - """ - Generates a procedural name for a test case for conv2d - - :param arch: compute capability of kernel being generated - :type arch: int - :param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad) - :type conv_kind: str - :param iterator_algorithm: the iterator algorithm applied - :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm - :param element_a: data type of operand A - :param element_b: data type of operand B - :param element_c: data type of operand C - :param element_accumulator: data type used in accumulation - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass_cppgen.OpcodeClass - :param threadblock_shape: indexable container of dimensions of threadblock tiles - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param stride_support: stride support of dgrad - :param alignment: int - :type alignment: int - - :return: str - """ - if iterator_algorithm is None: - iterator_algorithm = "AUTO" - if swizzle is None: - swizzle = 1 - name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}" - - return SubstituteTemplate( - name_format, - { - "arch": str(arch), - "conv_kind": conv_kind, - "iter_alg": iterator_algorithm, - "eA": DataTypeNames[element], - "eB": DataTypeNames[element], - "eC": DataTypeNames[element_output], - "opclass": opclass, - "acc": DataTypeNames[element_accumulator], - "tbM": str(threadblock_shape[0]), - "tbN": str(threadblock_shape[1]), - "tbK": str(threadblock_shape[2]), - "wM": str(threadblock_shape[0] // warp_count[0]), - "wN": str(threadblock_shape[1] // warp_count[1]), - "wK": str(threadblock_shape[2] // warp_count[2]), - "IM": str(instruction_shape[0]), - "IN": str(instruction_shape[1]), - "IK": str(instruction_shape[2]), - "stages": str(stages), - "swizzle": str(swizzle), - "split_k_mode": split_k_mode, - "split_k_slices": str(split_k_slices), - "activation": activation - } - ) - - -def conv2d_few_channel_problemsizes(channels): - problem_sizes = [ - Conv2DProblemSize( - 1, 8, 8, channels, - 16, 3, 3, channels, - 1, 1, - 2, 2, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 16, 16, channels, - 16, 3, 3, channels, - 1, 1, - 2, 2, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 16, 16, channels, - 16, 7, 7, channels, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 224, 224, channels, - 32, 7, 7, channels, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 224, 224, channels, - 64, 7, 7, channels, - 1, 1, - 2, 2, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 224, 224, channels, - 64, 5, 5, channels, - 1, 1, - 1, 1, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 224, 224, channels, - 64, 5, 5, channels, - 1, 1, - 2, 2, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - ] - - return problem_sizes - - -def validate_problem_size(ps, conv_kind, split_k_slices): - P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1 - Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1 - if P != ps.P or Q != ps.Q: - return False - - # Split-K (serial or parallel) is not supported for strided dgrad - if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1): - return False - return True - - -class Conv2dLauncherFrontend: - def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"): - self.operation = plan - self.conv_kind = plan.conv_kind - self.seed = seed - self.backend = backend - - self.dtype_A = plan._element_a - self.dtype_B = plan._element_b - self.dtype_C = plan._element_c - self.dtype_acc = plan._element_accumulator - self.layout_A = LayoutType.TensorNHWC - self.layout_B = LayoutType.TensorNHWC - self.layout_C = LayoutType.TensorNHWC - self.layout_D = LayoutType.TensorNHWC - - self.element_compute = DataType.f32 - - if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]: - self.rand_max = 1 - else: - self.rand_max = 4 - self.activation = plan.activation - - def uniform_init(self, size, dtype): - tensor = torch.ceil( - torch.empty(size=size, dtype=torch_type(dtype), device="cuda").uniform_(-self.rand_max - 0.5, self.rand_max - 0.5) - ).to(memory_format=torch.channels_last) - return tensor - - def reference(self, ps, A, B, C, alpha, beta, activation): - if self.conv_kind == ConvKind.Fprop: - torch_result = alpha * torch.ops.aten.conv2d( - A, - B, - stride=(ps.stride_h, ps.stride_w), - padding=(ps.pad_h, ps.pad_w), - dilation=(ps.dilation_h, ps.dilation_w) - ) + beta * C - elif self.conv_kind == ConvKind.Dgrad: - torch_result = alpha * torch.nn.grad.conv2d_input( - (ps.N, ps.C, ps.H, ps.W), - B, - A, - padding=(ps.pad_h, ps.pad_w), - stride=(ps.stride_h, ps.stride_w) - ) + beta * C - elif self.conv_kind == ConvKind.Wgrad: - torch_result = alpha * torch.nn.grad.conv2d_weight( - B, - (ps.K, ps.C, ps.R, ps.S), - A, - padding=(ps.pad_h, ps.pad_w), - stride=(ps.stride_h, ps.stride_w) - ) + beta * C - else: - raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.") - - if activation == cutlass_cppgen.backend.epilogue.relu: - torch_result = torch.nn.functional.relu(torch_result) - elif activation == cutlass_cppgen.backend.epilogue.leaky_relu: - torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5) - return torch_result - - def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0): - if self.conv_kind == ConvKind.Fprop: - tensor_A_size = (ps.N, ps.C, ps.H, ps.W) - tensor_B_size = (ps.K, ps.C, ps.R, ps.S) - tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) - elif self.conv_kind == ConvKind.Dgrad: - tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) - tensor_B_size = (ps.K, ps.C, ps.R, ps.S) - tensor_C_size = (ps.N, ps.C, ps.H, ps.W) - elif self.conv_kind == ConvKind.Wgrad: - tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) - tensor_B_size = (ps.N, ps.C, ps.H, ps.W) - tensor_C_size = (ps.K, ps.C, ps.R, ps.S) - else: - raise Exception(f"Conv kind {self.conv_kind} is not supported") - - torch.manual_seed(self.seed) - - tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A) - tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B) - tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C) - tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last) - args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D, - stride=(ps.stride_h, ps.stride_w), - padding=(ps.pad_h, ps.pad_w), - dilation=(ps.dilation_h, ps.dilation_w), - alpha=alpha, beta=beta, - split_k=(split_k_mode, split_k_slices)) - - args.sync() - - tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation) - - torch.cuda.synchronize() - passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06) - - return passed - - -def add_test( - cls, - cc, - conv_kind, - problem_sizes, - element, - element_accumulator, - element_output, - opclass, - threadblock_shape, - warp_count, - instruction_shape, - stages, - iterator_algorithm=None, - swizzle=None, - split_k_mode="serial", - split_k_slices=1, - activation = "identity" -): - """Create a test-running function with the given specification""" - test_name = get_name_conv2d( - cc, conv_kind, element, element_accumulator, - element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages, - iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation) - - def run(self): - # Create the plan - plan = cutlass_cppgen.Conv2d( - kind=conv_kind, - element=element, - element_accumulator=element_accumulator, - element_C=element_output, - element_D=element_output - ) - - # Set the opclass - plan.opclass = opclass - # Set the tile description - td = { - "threadblock_shape": threadblock_shape, - "warp_count": warp_count, - "stages": stages, - "instruction_shape": instruction_shape, - } - - plan.tile_description = td - # Set iterator algorithm - if iterator_algorithm is not None: - plan.iterator_algorithm = iterator_algorithm - # Set swizzling functor - if swizzle is not None: - plan.swizzling_stride = swizzle - - if activation != "identity": - if activation == "leaky_relu": - plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5) - else: - plan.activation = getattr(cutlass_cppgen.epilogue, activation) - - conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch") - - for ps in problem_sizes: - if not validate_problem_size(ps, conv_kind, split_k_slices): - continue - - self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0)) - - setattr(cls, test_name, run) - - return run - - -def get_conv_problems(): - # 64: minimum channel size - conv_problems = TestbedConv2dProblemSizes(64).all - - # Insert alignment 4 & 2 tests - conv_problems += [ - Conv2DProblemSize( - 1, 4, 4, 12, - 8, 3, 3, 12, - 0, 0, - 3, 3, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 4, 4, 14, - 8, 3, 3, 14, - 0, 0, - 3, 3, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - Conv2DProblemSize( - 1, 23, 56, 98, - 128, 3, 3, 98, - 4, 5, - 3, 3, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ), - ] - - return conv_problems diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py deleted file mode 100644 index d892b5df047d5121345d902a77aadf2256b4c3b5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py +++ /dev/null @@ -1,44 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import pathlib -import unittest - - -if __name__ == '__main__': - loader = unittest.TestLoader() - script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' - tests = loader.discover(script_dir, 'conv2d_*.py') - testRunner = unittest.runner.TextTestRunner() - results = testRunner.run(tests) - if not results.wasSuccessful(): - raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py deleted file mode 100644 index c9d4c52a9f75fb4c3bc809947bf48ba85356ec70..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py +++ /dev/null @@ -1,309 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Tests emitting a CUTLASS kernel to a PyTorch CUDA extension -""" - -import random -import tempfile -import unittest - -from cutlass_library import ConvMode - -import cutlass_cppgen - -if cutlass_cppgen.utils.datatypes.is_torch_available(): - import torch - - -def _initialize(dtype, M: int, N: int, K: int): - """ - Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K - - :param dtype: data type of tensors - :param M: M dimension of GEMM problem - :type M: int - :param N: N dimension of GEMM problem - :type N: int - :param K: N dimension of GEMM problem - :type K: int - - :return: initialized tensors A, B, C, and D - :rtype: list - """ - sizes = [(M, K), (K, N), (M, N), (M, N)] - return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes] - - -def _generate_problems(dtype, num): - """ - Utility function to generate `num` GEMMs of random sizes - - :param dtype: data type of tensors - :param num: number of GEMMs to generate - :type num: int - - :return: lists of A, B, C, and D tensors - :rtype: list - """ - valid_sizes = [128, 256, 512, 1024] - As, Bs, Cs, Ds = [], [], [], [] - for _ in range(num): - M, N, K = [random.choice(valid_sizes) for _ in range(3)] - A, B, C, D = _initialize(dtype, M, N, K) - As.append(A) - Bs.append(B) - Cs.append(C) - Ds.append(D) - return As, Bs, Cs, Ds - -def _generate_conv2d_problem(conv_kind, dtype, ps): - """ - Utility function to generate conv2d inputs - - :param conv_kind: kind of convolution - :type conv_kind: str - :param dtype: data type of tensors - :param problem_size: the conv2d problem size - :type problem_size: cutlass_cppgen.shape.Conv2DProblemSize - - :return: initialized tensors A, B, C, and D - :rtype: list - """ - if conv_kind == "fprop": - tensor_A_size = (ps.N, ps.C, ps.H, ps.W) - tensor_B_size = (ps.K, ps.C, ps.R, ps.S) - tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) - elif conv_kind == "dgrad": - tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) - tensor_B_size = (ps.K, ps.C, ps.R, ps.S) - tensor_C_size = (ps.N, ps.C, ps.H, ps.W) - else: - tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) - tensor_B_size = (ps.N, ps.C, ps.H, ps.W) - tensor_C_size = (ps.K, ps.C, ps.R, ps.S) - sizes = [tensor_A_size, tensor_B_size, tensor_C_size] - return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes] - - -@unittest.skipIf(not cutlass_cppgen.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests') -class PyTorchExtensionTest(unittest.TestCase): - - def test_gemm(self): - random.seed(2023) - - dtype = torch.float16 - plan = cutlass_cppgen.op.Gemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) - op = plan.construct() - - with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass_cppgen.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) - - A, B, C, _ = _initialize(dtype, 1024, 256, 512) - - D_ref = A @ B - D = mod.run(A, B) - assert torch.allclose(D, D_ref) - - D = mod.run(A, B, C) - assert torch.allclose(D, D_ref) - - D = mod.run(A, B, C, 1.0) - assert torch.allclose(D, D_ref) - - D = mod.run(A, B, C, 1.0, 0.0) - assert torch.allclose(D, D_ref) - - alpha = 2.0 - beta = -1.0 - D_ref = (A @ B) * alpha + (beta * C) - D = mod.run(A, B, C, alpha, beta) - assert torch.allclose(D, D_ref) - - def test_grouped_gemm(self): - random.seed(2023) - - dtype = torch.float16 - plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) - op = plan.construct() - - with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) - - As, Bs, Cs, _ = _generate_problems(dtype, 50) - - def check_all(X, Y): - for x, y in zip(X, Y): - assert torch.allclose(x, y) - - Ds_ref = [a @ b for a, b in zip(As, Bs)] - Ds = mod.run(As, Bs) - check_all(Ds, Ds_ref) - - Ds = mod.run(As, Bs, Cs) - check_all(Ds, Ds_ref) - - Ds = mod.run(As, Bs, Cs, 1.0) - check_all(Ds, Ds_ref) - - Ds = mod.run(As, Bs, Cs, 1.0, 0.0) - check_all(Ds, Ds_ref) - - alpha = 2.0 - beta = -1.0 - Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] - Ds = mod.run(As, Bs, Cs, alpha, beta) - check_all(Ds, Ds_ref) - - def test_conv2d_fprop(self): - torch.manual_seed(2023) - - dtype = torch.float16 - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) - plan.activation = "relu" - - op = plan.construct() - with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - - problem_size = cutlass_cppgen.shape.Conv2DProblemSize( - 1, 4, 4, 16, - 8, 3, 3, 16, - 0, 0, - 3, 3, - 1, 1 - ) - - A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size) - stride = (problem_size.stride_h, problem_size.stride_w) - padding = (problem_size.pad_h, problem_size.pad_w) - - alpha = 1.0 - beta = 0.5 - - D_ref = alpha * torch.ops.aten.conv2d( - A, B, stride=stride, padding=padding - ) + beta * C - D_ref = torch.nn.functional.relu(D_ref) - D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta) - - assert torch.allclose(D, D_ref) - - # Test serial split-K - D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) - assert torch.allclose(D, D_serial_split_k) - - # Test parallel split-K - D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) - assert torch.allclose(D, D_parallel_split_k) - - - def test_conv2d_dgrad(self): - torch.manual_seed(2023) - dtype = torch.float16 - plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) - - op = plan.construct() - with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - - problem_size = cutlass_cppgen.shape.Conv2DProblemSize( - 1, 4, 4, 16, - 8, 3, 3, 16, - 0, 0, - 3, 3, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ) - - A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size) - stride = (problem_size.stride_h, problem_size.stride_w) - padding = (problem_size.pad_h, problem_size.pad_w) - - alpha = 1.0 - beta = 0.5 - input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W) - D_ref = alpha * torch.nn.grad.conv2d_input( - input_size, B, A, - stride=stride, padding=padding - ) + beta * C - D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, ) - - assert torch.allclose(D, D_ref) - - def test_conv2d_wgrad(self): - torch.manual_seed(2023) - dtype = torch.float16 - plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) - - op = plan.construct() - with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - - problem_size = cutlass_cppgen.shape.Conv2DProblemSize( - 1, 4, 4, 16, - 8, 3, 3, 16, - 0, 0, - 3, 3, - 1, 1, - ConvMode.CrossCorrelation, - 1, 1 - ) - - A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size) - stride = (problem_size.stride_h, problem_size.stride_w) - padding = (problem_size.pad_h, problem_size.pad_w) - - alpha = 1.0 - beta = 0.5 - weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S) - D_ref = alpha * torch.nn.grad.conv2d_weight( - B, weight_size, A, - stride=stride, padding=padding - ) + beta * C - D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta) - - assert torch.allclose(D, D_ref) - - # Test serial split-K - D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) - assert torch.allclose(D, D_serial_split_k) - - # Test parallel split-K - D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) - assert torch.allclose(D, D_parallel_split_k) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py deleted file mode 100644 index 5467469e74e05573fb297b009914e0980e5ab222..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py +++ /dev/null @@ -1,198 +0,0 @@ -################################################################################ -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ -""" -Unit test for compute node in SM90 -""" - -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend import * -from cutlass_cppgen.epilogue import * -from cutlass_cppgen import swizzle - -from utils.evt_testbed import EVTTestBed, EVTTestCaseBase - -cutlass_cppgen.set_log_level(logging.WARNING) - - -@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") -class TestEVTCompute(EVTTestCaseBase): - - def test_arith(self): - """ - Test Arithmatic op - """ - def evt_arith_compute(accum, C, alpha, beta, gamma): - D = ((accum + C) * alpha - gamma) / beta - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.5, - "beta": 0.5, - "gamma": 2.5, - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_arith_compute, example_inputs) - input_keys = ["C", "alpha", "beta", "gamma"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_func_call(self): - """ - Test Function call - """ - def evt_func_call(accum, C, alpha, beta, gamma): - D = multiply_add(relu(accum + alpha) + C, beta, gamma) - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.5, - "beta": 0.5, - "gamma": 2.5, - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_func_call, example_inputs) - input_keys = ["C", "alpha", "beta", "gamma"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_func_call2(self): - """ - Test Function call - """ - - def evt_func_call2(accum, C, alpha, beta): - D = maximum(alpha * accum + beta * C, 0.0) - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.5, - "beta": 0.5, - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_func_call2, example_inputs) - input_keys = ["C", "alpha", "beta"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_tanh(self): - """ - Test Tanh op - """ - def evt_tanh(accum): - D = tanh(accum) - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_tanh, example_inputs) - input_keys = [] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_sigmoid(self): - """ - Test Sigmoid op - """ - def evt_sigmoid(accum): - D = sigmoid(accum) - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs) - input_keys = [] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_gelu(self): - """ - Test GELU op - """ - def evt_gelu(accum): - D = gelu(accum) - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_gelu, example_inputs) - input_keys = [] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_exp(self): - """ - Test Exp op - """ - def evt_exp(accum): - D = exp(accum) - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)) - } - - launcher = EVTTestBed(self.element, evt_exp, example_inputs) - input_keys = [] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py deleted file mode 100644 index f5a7b7f7a336dce0651f299d26b17df04952be99..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +++ /dev/null @@ -1,173 +0,0 @@ -################################################################################ -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -""" -Unit test for store nodes in SM90 -""" - -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend import * -from cutlass_cppgen.epilogue import * - -from utils.evt_testbed import EVTTestBed, EVTTestCaseBase - -cutlass_cppgen.set_log_level(logging.WARNING) - - -@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") -class TestEVTLayout(EVTTestCaseBase): - - def test_permute_1(self): - """ - Returning a tensor with shape [m, n] - """ - def evt_permute(accum, alpha, C): - F = alpha * accum - F_permute = permute(F, indices=(0, 2, 1)) - D_permute = F_permute + permute(C, indices=(0, 2, 1)) - D = permute(D_permute, indices=(0, 2, 1)) - return D, F - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 0.5, - "C": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_permute, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") - def test_permute_2(self): - """ - Returning a tensor with shape [m, n] - """ - def evt_permute(accum, alpha, C): - F = alpha * accum - F_permute = permute(F, indices=(0, 2, 1)) - D = F_permute + C - return D, F - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 0.5, - "C": self.fake_tensor(self.element, (l, n, m)), - "F": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, n, m)), - } - - launcher = EVTTestBed(self.element, evt_permute, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") - def test_permute_3(self): - """ - Returning a tensor with shape [m, n] - """ - def evt_permute(accum, alpha, C): - F = alpha * accum - F_permute = permute(F, indices=(1, 0, 2)) - D = F_permute + C - return D, F - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 0.5, - "C": self.fake_tensor(self.element, (m, l, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (m, l, n)), - } - - launcher = EVTTestBed(self.element, evt_permute, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_reshape(self): - """ - Test reshape - """ - def evt_reshape(accum, alpha, TensorE): - F = alpha * accum - E_reshape = reshape(TensorE, new_shape=(512, 1)) - D = F + E_reshape - return D - - example_inputs = { - "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), - "alpha": 0.5, - "TensorE": self.fake_tensor(self.element, (16, 32)), - "D": self.fake_tensor(self.element, (self.l, self.m, self.n)), - } - - launcher = EVTTestBed(self.element, evt_reshape, example_inputs) - input_keys = ["alpha", "TensorE"] - result_keys = ["D"] - launcher.verify(self.problem_size, input_keys, result_keys, self.l) - - def test_reshape2(self): - """ - Test reshape - """ - def evt_reshape(accum, alpha, TensorE): - F = alpha * accum - F_reshape = reshape(F, new_shape=(2, 3, 512, 256)) - D = F_reshape + TensorE - return D - - example_inputs = { - "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), - "alpha": 0.5, - "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)), - "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)), - } - - launcher = EVTTestBed(self.element, evt_reshape, example_inputs) - input_keys = ["alpha", "TensorE"] - result_keys = ["D"] - launcher.verify(self.problem_size, input_keys, result_keys, self.l) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py deleted file mode 100644 index 57a5c6bb17bb44bf294cc7a6a749c706601034f6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +++ /dev/null @@ -1,142 +0,0 @@ -################################################################################ -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -""" -Unit test for load nodes in SM90 -""" - -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend import * -from cutlass_cppgen.epilogue import * - -from utils.evt_testbed import EVTTestBed, EVTTestCaseBase - -cutlass_cppgen.set_log_level(logging.WARNING) - - -@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") -class TestEVTLoad(EVTTestCaseBase): - - def test_tensor_load(self): - """ - Load extra tensor with shape [m, n] - """ - def evt_tensor_load(accum, C, aux, aux_batch): - D = accum + C + aux + aux_batch - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "aux": self.fake_tensor(self.element, (m, n)), - "aux_batch": self.fake_tensor(np.float32, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs) - input_keys = ["C", "aux", "aux_batch"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_row_broadcast(self): - """ - Load extra tensor with shape [1, n] - """ - def evt_row_broadcast(accum, C, bias, bias_batch): - D = accum + C + bias + bias_batch - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "bias": self.fake_tensor(self.element, (n,)), - "bias_batch": self.fake_tensor(np.float32, (l, 1, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs) - input_keys = ["C", "bias", "bias_batch"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_column_broadcast(self): - """ - Load extra tensor with shape [m, 1] - """ - def evt_column_broadcast(accum, C, bias, bias_batch): - D = accum + C + bias + bias_batch - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "bias": self.fake_tensor(self.element, (m, 1)), - "bias_batch": self.fake_tensor(np.float32, (l, m, 1)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs) - input_keys = ["C", "bias", "bias_batch"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_scalar_broadcast(self): - """ - Load extra tensor with shape [1, 1] - """ - def evt_scalar_broadcast(accum, C, alpha, alpha_batch): - D = accum + C + alpha + alpha_batch - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "C": self.fake_tensor(self.element, (l, m, n)), - "alpha": 0.5, - "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs) - input_keys = ["C", "alpha", "alpha_batch"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py deleted file mode 100644 index 30dc8fe0d5ec413f1da57a8fa0875ed5e7baa887..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +++ /dev/null @@ -1,319 +0,0 @@ -################################################################################ -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -""" -Unittest for mixed types of nodes in SM90 -""" - -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend import * -from cutlass_cppgen.epilogue import * -from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK - -from utils.evt_testbed import EVTTestBed, EVTTestCaseBase - -cutlass_cppgen.set_log_level(logging.WARNING) - - -@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") -class TestEVTMixed(EVTTestCaseBase): - - def test_same_variable_used_multiple_times(self): - """ - The same variable z0 is used multiple times - """ - def evt_aux_store(accum): - z0 = relu(accum) - D = z0 + z0 - return z0, D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - "z0": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) - input_keys = ["accum"] - result_keys = ["z0", "D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_no_lca(self): - """ - The same variable z0 is used multiple times - """ - def evt_no_lca(accum, bias): - E = relu(accum) - F = E + bias - tmp_2 = E + 2 - D = tmp_2 + E - return D - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)), - } - - launcher = EVTTestBed(self.element, evt_no_lca, example_inputs) - input_keys = ["accum", "bias"] - result_keys = ["D"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_mixed_dag(self): - def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): - F = alpha * accum + (beta * C + aux) - F_row_max = max(F, dim=[0, 1]) - E = relu(F + 1) + cbias + rbias - E_col_max = max(E, dim=[0, 2]) - D = E + F - return D, F, F_row_max, E_col_max - - if device_cc() == 80: - alignments = [2, 4, 8] - else: - # Sm90 EVT currently only supports 128-bit alignment - alignments = [8,] - for align in alignments: - for m, n, k, l in self.get_problem_sizes(align): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "beta": 1.0, - "aux": self.fake_tensor(self.element, (l, m, n)), - "cbias": self.fake_tensor(self.element, (m, 1)), - "rbias": self.fake_tensor(self.element, (n,)), - "D": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "F_row_max": self.fake_tensor(DataType.f32, (n,)), - "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) - } - - launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs) - input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] - result_keys = ["D", "F", "F_row_max", "E_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") - def test_mixed_dag_float(self): - def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): - F = alpha * accum + (beta * C + aux) - F_row_max = max(F, dim=[0, 1]) - E = relu(F + 1) + cbias + rbias - E_col_max = max(E, dim=[0, 2]) - D = E + F - return D, F, F_row_max, E_col_max - - for align in [3, 2, 4]: - for m, n, k, l in self.get_problem_sizes(align): - example_inputs = { - "accum": self.fake_tensor(np.float32, (l, m, n)), - "alpha": 1.0, - "C": self.fake_tensor(np.float32, (l, m, n)), - "beta": 1.0, - "aux": self.fake_tensor(np.float32, (l, m, n)), - "cbias": self.fake_tensor(np.float32, (m, 1)), - "rbias": self.fake_tensor(np.float32, (n,)), - "D": self.fake_tensor(np.float32, (l, m, n)), - "F": self.fake_tensor(np.float32, (l, m, n)), - "F_row_max": self.fake_tensor(np.float32, (n,)), - "E_col_max": self.fake_tensor(np.float32, (m, 1)) - } - launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs) - input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] - result_keys = ["D", "F", "F_row_max", "E_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") - def test_mixed_dag_stage2(self): - def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): - F = alpha * accum + (beta * C + aux) - F_row_max = max(F, dim=[0, 1]) - E = relu(F + 1) + cbias + rbias - E_col_max = max(E, dim=[0, 2]) - D = E + F - return D, F, F_row_max, E_col_max - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "beta": 1.0, - "aux": self.fake_tensor(self.element, (l, m, n)), - "cbias": self.fake_tensor(self.element, (m, 1)), - "rbias": self.fake_tensor(self.element, (n,)), - "D": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "F_row_max": self.fake_tensor(DataType.f32, (n,)), - "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) - } - - launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2) - input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] - result_keys = ["D", "F", "F_row_max", "E_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") - def test_mixed_dag_partition_k(self): - def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): - F = alpha * accum + (beta * C + aux) - F_row_max = max(F, dim=[0, 1]) - E = relu(F + 1) + cbias + rbias - E_col_max = max(E, dim=[0, 2]) - D = E + F - return D, F, F_row_max, E_col_max - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "beta": 1.0, - "aux": self.fake_tensor(self.element, (l, m, n)), - "cbias": self.fake_tensor(self.element, (m, 1)), - "rbias": self.fake_tensor(self.element, (n,)), - "D": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "F_row_max": self.fake_tensor(DataType.f32, (n,)), - "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) - } - - tile_description = { - "threadblock_shape": [128, 128, 64], - "warp_count": [2, 2, 2] - } - - launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2) - input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] - result_keys = ["D", "F", "F_row_max", "E_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") - def test_mixed_dag_stream_k(self): - def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): - F = alpha * accum + (beta * C + aux) - F_row_max = max(F, dim=[0, 1]) - E = relu(F + 1) + cbias + rbias - E_col_max = max(E, dim=[0, 2]) - D = E + F - return D, F, F_row_max, E_col_max - - # High per-sm occupancy tile_description - tile_description = { - "threadblock_shape": [128, 128, 32], - "warp_count": [2, 2, 1], - "stages": 3 - } - tds = [None, tile_description] - for td in tds: - for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]): - if l == 1: - example_inputs = { - "accum": self.fake_tensor(self.element, (m, n)), - "alpha": 1.0, - "C": self.fake_tensor(self.element, (m, n)), - "beta": 1.0, - "aux": self.fake_tensor(self.element, (m, n)), - "cbias": self.fake_tensor(self.element, (m, 1)), - "rbias": self.fake_tensor(self.element, (n,)), - "D": self.fake_tensor(self.element, (m, n)), - "F": self.fake_tensor(self.element, (m, n)), - "F_row_max": self.fake_tensor(DataType.f32, (n,)), - "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) - } - else: - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 1.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "beta": 1.0, - "aux": self.fake_tensor(self.element, (l, m, n)), - "cbias": self.fake_tensor(self.element, (m, 1)), - "rbias": self.fake_tensor(self.element, (n,)), - "D": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "F_row_max": self.fake_tensor(DataType.f32, (n,)), - "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) - } - - if td is not None: - launcher = EVTTestBed( - self.element, evt_mixed_dag, example_inputs, - tile_description=td, - swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") - else: - launcher = EVTTestBed( - self.element, evt_mixed_dag, example_inputs, - swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") - - input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] - result_keys = ["D", "F", "F_row_max", "E_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_mixed_dag_no_batch(self): - def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias): - F = alpha * accum + (beta * C + aux) - F_row_max = max(F, dim=[0, 1]) - E = relu(F + 1) + cbias + rbias - E_col_max = max(E, dim=[0, 2]) - D = E + F - return D, F, F_row_max, E_col_max - - for m, n, k, _ in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (m, n)), - "alpha": 1.0, - "C": self.fake_tensor(self.element, (m, n)), - "beta": 1.0, - "aux": self.fake_tensor(self.element, (m, n)), - "cbias": self.fake_tensor(self.element, (m, 1)), - "rbias": self.fake_tensor(self.element, (n,)), - "D": self.fake_tensor(self.element, (m, n)), - "F": self.fake_tensor(self.element, (m, n)), - "F_row_max": self.fake_tensor(DataType.f32, (n,)), - "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) - } - - launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs) - input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] - result_keys = ["D", "F", "F_row_max", "E_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, 1) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py deleted file mode 100644 index b47f11e4f3bde3499948ae68b1b5bb79347f0fd1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +++ /dev/null @@ -1,180 +0,0 @@ -################################################################################ -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -""" -Unit test for store nodes in SM90 -""" - -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend import * -from cutlass_cppgen.epilogue import * - -from utils.evt_testbed import EVTTestBed, EVTTestCaseBase - -cutlass_cppgen.set_log_level(logging.WARNING) - - -@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") -class TestEVTStore(EVTTestCaseBase): - - @unittest.skipIf(device_cc() != 90, "This test is only for CC 90") - def test_invalid_store(self): - """ - Test invalid store - """ - def evt_invalid_store(accum): - D = accum - F = D + 1 # D has users, which is not allowed on SM90 or higher - return D, F - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)) - } - with self.assertRaisesRegex( - RuntimeError, - r"On SM90 or higher, D is expected to be a output node with 0 users " - r"to enable smem reuse between C and D, but got 1" - ): - launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs) - - break # Only need to test once - - def test_aux_store(self): - """ - Returning a tensor with shape [m, n] - """ - def evt_aux_store(accum, alpha, C): - F = alpha * accum - D = F + C - return D, F - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 0.5, - "C": self.fake_tensor(self.element, (l, m, n)), - "F": self.fake_tensor(self.element, (l, m, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_col_reduce(self): - """ - Reduction [m, n] -> [m, 1] - """ - def evt_row_reduce(accum, alpha, C): - acc_row_max = max(accum, dim=[2,]) - F = alpha * accum - F_row_max = max(F, dim=[0, 2]) - D = F + C - return D, F_row_max, acc_row_max - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 2.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "F_row_max": self.fake_tensor(np.float32, (m, 1)), - "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F_row_max", "acc_row_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_row_reduce(self): - """ - Reduction [m, n] -> [n] - """ - def evt_col_reduce(accum, alpha, C): - acc_col_max = max(accum, dim=[1,]) - F = alpha * accum - F_col_max = max(F, dim=[0, 1]) - D = F + C - return D, F_col_max, acc_col_max - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 2.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "F_col_max": self.fake_tensor(np.float32, (n,)), - "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F_col_max", "acc_col_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - def test_scalar_reduce(self): - """ - Reduction [m, n] -> [1,] - """ - def evt_scalar_reduce(accum, alpha, C): - acc_max = max(accum, dim=[1, 2]) - F = alpha * accum - F_max = max(F, dim=[0, 1, 2]) - D = F + C - return D, F_max, acc_max - - for m, n, k, l in self.get_problem_sizes(8): - example_inputs = { - "accum": self.fake_tensor(self.element, (l, m, n)), - "alpha": 2.0, - "C": self.fake_tensor(self.element, (l, m, n)), - "acc_max": self.fake_tensor(np.float32, (l, 1, 1)), - "F_max": self.fake_tensor(np.float32, (1,)), - "D": self.fake_tensor(self.element, (l, m, n)), - } - - launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs) - input_keys = ["C", "alpha"] - result_keys = ["D", "F_max", "acc_max"] - launcher.verify((m, n, k), input_keys, result_keys, l) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py deleted file mode 100644 index 5bb84e2e8c85e602b45b9ee18ce324accd3a32cd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py +++ /dev/null @@ -1,44 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import pathlib -import unittest - - -if __name__ == '__main__': - loader = unittest.TestLoader() - script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' - tests = loader.discover(script_dir, 'evt_*.py') - testRunner = unittest.runner.TextTestRunner() - results = testRunner.run(tests) - if not results.wasSuccessful(): - raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py deleted file mode 100644 index 62d375d856ffaef6be50b39b76121e0eb78a7465..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py +++ /dev/null @@ -1,235 +0,0 @@ -################################################################################ -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################ - -""" -Testbed classes of EVT -""" - -import torch -import unittest - -import cutlass_cppgen -from cutlass_cppgen import Tensor -import cutlass_cppgen.backend.evt -from cutlass_cppgen.shape import GemmCoord -from cutlass_cppgen.utils.datatypes import torch_type -from cutlass_cppgen.utils.profiler import CUDAEventProfiler - - -class EVTReferenceModule: - def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor): - self.layout_A = layout_A - self.layout_B = layout_B - self.layout_C = layout_C - self.epilogue_visitor = epilogue_visitor - - def run(self, A, B, C, problem_size, alpha, beta, batch=1): - if self.layout_A == cutlass_cppgen.LayoutType.RowMajor: - A_row = A.view((batch, problem_size.m, problem_size.k)) - else: - A_col = A.view((batch, problem_size.k, problem_size.m)) - A_row = torch.permute(A_col, (0, 2, 1)) - - if self.layout_B == cutlass_cppgen.LayoutType.RowMajor: - B_row = B.view((batch, problem_size.k, problem_size.n)) - else: - B_col = B.view((batch, problem_size.n, problem_size.k)) - B_row = torch.permute(B_col, (0, 2, 1)) - - if self.layout_C == cutlass_cppgen.LayoutType.RowMajor: - C_row = C.view((batch, problem_size.m, problem_size.n)) - else: - C_col = C.view((batch, problem_size.n, problem_size.m)) - C_row = torch.permute(C_col, (0, 2, 1)) - - out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta - - if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor: - out = torch.permute(out_row, (0, 2, 1)) - else: - out = out_row - - return torch.flatten(out) - - def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None): - # Running the mainloop - accum = self.run( - A, B, C, problem_size, 1.0, 0.0, batch=batch - ).reshape(batch, problem_size.m, problem_size.n) - - # Running the epilogue - epilogue_args["accum"] = accum - references = self.epilogue_visitor(**epilogue_args) - - # Return the results - if not isinstance(references, tuple): - references = (references,) - return references - - -class EVTTestBed: - """ - Epilogue Visitor Testbed - """ - def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None: - self.element = element - layout = cutlass_cppgen.LayoutType.RowMajor - self.example_inputs = example_inputs - - # Create the Gemm plan - self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32) - - if "tile_description" in kwargs: - self.plan.tile_description = kwargs["tile_description"] - - if "swizzling_functor" in kwargs: - self.plan.swizzling_functor = kwargs["swizzling_functor"] - - # Compile the epilogue visitor - epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs) - if "epilogue_stages" in kwargs: - epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"] - self.plan.epilogue_visitor = epilogue_visitor - - # Reference model - self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor) - - self.profile = profile - - def get_torch_tensor(self, shape, dtype=None, fill=None): - if dtype is None: - dtype = self.element - - dtype = torch_type(dtype) - if fill is None: - return torch.ceil( - torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5) - ) - else: - return torch.full(shape, fill, dtype=dtype, device="cuda") - - def verify(self, problem_size, input_keys, result_keys, batch_count=1): - """ - Verify the results - """ - problem_size = GemmCoord(*problem_size) - - # Initiate the GEMM arguments - tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k)) - tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n)) - - # Initialize the epilogue args - epilogue_args = {} - for key in self.example_inputs.keys(): - if key in input_keys: - tensor = self.example_inputs[key] - if isinstance(tensor, Tensor): - epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element) - else: - epilogue_args[key] = tensor - elif key in result_keys: - tensor = self.example_inputs[key] - if isinstance(tensor, Tensor): - if "max" in key: - fill = -1000 - else: - fill = 0 - epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element, fill=fill) - else: - epilogue_args[key] = tensor - - tensor_D = epilogue_args["D"] - if "C" in epilogue_args: - tensor_C = epilogue_args["C"] - else: - tensor_C = tensor_D - # Run the device kernel - self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args) - - # Run the host reference - evt_args_inputs = {} - for key in input_keys: - evt_args_inputs[key] = epilogue_args[key] - - reference_results = self.reference_fn( - tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs) - - # Compare the results - for result, ref in zip(result_keys, reference_results): - assert torch.equal( - epilogue_args[result].flatten(), - ref.masked_fill(torch.isnan(ref), float('inf')).flatten()) - - # Run profile - if self.profile: - profiler = CUDAEventProfiler( - self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D, - visitor_args = epilogue_args - ) - print(f"Cutlass Python Duration: {profiler()}") - - -class EVTTestCaseBase(unittest.TestCase): - """ - Base class for EVT Unittest - """ - def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None: - super().__init__(methodName) - - self.element = cutlass_cppgen.DataType.f16 - self.l, self.m, self.n, self.k = lmnk - - self.problem_size = (self.m, self.n, self.k) - - torch.random.manual_seed(42) - - def fake_tensor(self, element, shape, stride=None): - if stride is None: - return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor) - else: - return Tensor(element=element, shape=shape, stride=stride) - - def get_problem_sizes(self, alignment, k=None, batch_count=[3,]): - k = k if k else self.k - problem_size_m = [alignment, 512 - 3 * alignment] - problem_size_n = [alignment, 512 - alignment] - if alignment % 8 == 0: - problem_size_m.append(768) - problem_size_n.append(768) - problem_size_l = batch_count - problem_sizes = [] - for m in problem_size_m: - for n in problem_size_n: - for l in problem_size_l: - problem_sizes.append((m, n, k, l)) - - return problem_sizes diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py deleted file mode 100644 index 155426ab902d1f99eafc7b03c388fc79b4520317..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py +++ /dev/null @@ -1,134 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -High-level tests for running batched GEMMs -""" - -from functools import partial -import logging -from math import prod -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc -import torch - -from utils import LayoutCombination - -cutlass_cppgen.set_log_level(logging.WARNING) - -torch.manual_seed(2023) - - -def pytorch_reference(A, B, C, alpha, beta): - # Get the batch count. Assume that any of A, B, and C - # with a batch dimension ahve matching batch count. Thus, - # we break out of the loop once we have found the first - # tensor containing a batch dimension. - batch_count = (1,) - for tensor in [A, B, C]: - if len(tensor.shape) > 2: - batch_count = tensor.shape[:-2] - break - - int_batch_count = prod(batch_count) - - def add_batch(tensor): - if len(tensor.shape) == 2: - return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1) - else: - return tensor.reshape(-1, tensor.size(-2), tensor.size(-1)) - - # Reshape tensors to have batch dimension - A = add_batch(A) - B = add_batch(B) - C = add_batch(C) - - ret = (torch.bmm(A, B) * alpha) + (C * beta) - reshape_vals = batch_count + C.shape[-2:] - return ret.reshape(*reshape_vals) - - -def initialize(rows, cols, batch): - tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half() - if len(batch) > 0 and prod(batch) > 1: - reshape_vals = batch + (rows, cols) - return tensor.reshape(*reshape_vals) - else: - return tensor.reshape(rows, cols) - - -class GemmF16Batched(unittest.TestCase): - def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool): - M = 512 - N = 256 - K = 128 - alpha = 1. - beta = 2. - - A = initialize(M, K, batch_count if batch_A else (1,)) - B = initialize(K, N, batch_count if batch_B else (1,)) - C = initialize(M, N, batch_count if batch_C else (1,)) - D = initialize(M, N, batch_count) - - plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32) - plan.run(A, B, C, D, alpha, beta) - reference = pytorch_reference(A, B, C, alpha, beta) - assert reference.equal(D) - - def test_batched_ABC(self): - self.run_batched((3,), True, True, True) - self.run_batched((2, 3), True, True, True) - - def test_batched_AB(self): - self.run_batched((3,), True, True, False) - self.run_batched((2, 3), True, True, False) - - def test_batched_AC(self): - self.run_batched((3,), True, False, True) - self.run_batched((2, 3), True, False, True) - - def test_batched_BC(self): - self.run_batched((3,), False, True, True) - self.run_batched((2, 3), False, True, True) - - def test_batched_A(self): - self.run_batched((3,), True, False, False) - self.run_batched((2, 3), True, False, False) - - def test_batched_B(self): - self.run_batched((3,), False, True, False) - self.run_batched((2, 3), False, True, False) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py deleted file mode 100644 index dbd26951ec5d8a1eb6cbe38491c64fde2873b9c3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py +++ /dev/null @@ -1,128 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with F16 operands on SM80 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 80 -dtype = cutlass_cppgen.DataType.f16 - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF16Sm80(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF16Sm80StreamK(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - -add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) - -# Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) - -# Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) - -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) - -# Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py deleted file mode 100644 index 61aa295b966daf5943e7092572c98ee20143e2b5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py +++ /dev/null @@ -1,146 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with F16 operands on SM90 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 90 -dtype = cutlass_cppgen.DataType.f16 - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF16Sm90(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype, - warp_count=None, compilation_modes=['nvcc']) - -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -# Tests with 1x1x1 clusters -add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1]) -add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) -add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) - -# Tests with different cluster shapes -add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1]) - -# Tests for different schedule modes -add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4], - element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, - opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) -add_test_schedule( - cluster_shape=[1, 1, 1], - kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized -) -add_test_schedule( - cluster_shape=[1, 1, 1], - kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, - epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative -) -add_test_schedule( - cluster_shape=[2, 1, 1], - kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized -) -add_test_schedule( - cluster_shape=[2, 1, 1], - kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, - epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative -) - -# Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) -add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8]) -add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8]) -add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8]) -add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8]) -add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8]) - -# Tests with void-C kernels -add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, - element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, - cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py deleted file mode 100644 index bf662b9208ab2a5343d0fd11106835b7d9a5b2e9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py +++ /dev/null @@ -1,104 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with F32 operands on SM80 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 80 -dtype = cutlass_cppgen.DataType.f32 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF32Sm80(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF32Sm80StreamK(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) - -# Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) -# Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) - -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) - -# Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py deleted file mode 100644 index 3075ddf74bf2a119759ca1a3e47c0815f4b0923c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py +++ /dev/null @@ -1,103 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with F64 operands on SM80 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 80 -dtype = cutlass_cppgen.DataType.f64 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF64Sm80(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF64Sm80StreamK(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) - -# Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) - -# Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) - -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) - -# Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, - element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py deleted file mode 100644 index 9bf36fc77436fef22882e98c752b7a599cf7fb95..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py +++ /dev/null @@ -1,71 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with F64 operands on SM90 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 90 -dtype = cutlass_cppgen.DataType.f64 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF64Sm90(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], - element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) - -add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) -add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) -add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) -add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py deleted file mode 100644 index fef6d457a6528a61613d1295877a2b6b8f80fef5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py +++ /dev/null @@ -1,112 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with S8 operands on SM90 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 90 -dtype = cutlass_cppgen.DataType.e4m3 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF8E4M3Sm90(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc']) - -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -# Test with 1x1x1 clusters -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) - -# Tests with different cluster shapes -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) - -# Tests with warp-specialized ping-pong schedule -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, - kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) - -# Tests for SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) -add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) - - -# -# Add a test for E5M2 -# -dtype = cutlass_cppgen.DataType.e5m2 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmF8E5M2Sm90(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) - -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -# Tests with 1x1x1 clusters -add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, - element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py deleted file mode 100644 index 0a002a5fbad80de5f7b29e42db0806469244914c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py +++ /dev/null @@ -1,75 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with mixed operands on SM80 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 80 -dtype =cutlass_cppgen.DataType.f16 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmMixedSm80(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], - opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32) - -# Test with upcast on A -add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) -add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) - -# Test with upcast on B -add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) -add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py deleted file mode 100644 index e226e23684147cb0a9cd5c1270468eb96c67ba15..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py +++ /dev/null @@ -1,103 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with S8 operands on SM80 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 80 -dtype = cutlass_cppgen.DataType.s8 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmS8Sm80(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmS8Sm80StreamK(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) - -# Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) - -# Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) - -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) - -# Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py deleted file mode 100644 index ec0101f78da3b62b599a5deeb89f5596a7e515ce..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py +++ /dev/null @@ -1,98 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Low-level functionality tests for GEMM with S8 operands on SM90 -""" - -from functools import partial -import logging -import unittest - -import cutlass_cppgen -from cutlass_cppgen.backend.utils.device import device_cc - -from utils import LayoutCombination, add_test_gemm - - -cutlass_cppgen.set_log_level(logging.WARNING) -cc = 90 -dtype = cutlass_cppgen.DataType.s8 - - -@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') -class GemmS8Sm90(unittest.TestCase): - """ - Wrapper class to which tests will be added dynamically in __main__ - """ - pass - - -add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc']) - -add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) - -# Tests with 1x1x1 clusters -add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) - -# Tests with different cluster shapes -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) - -# Tests with warp-specialized ping-pong schedule -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, - kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) - -# Tests for SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) -add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, - element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py deleted file mode 100644 index 6ffda5b47e37f184c2352f0ee4e737635dbd4147..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py +++ /dev/null @@ -1,423 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from math import prod -import os -import re -import subprocess - -import torch - -from cutlass_library import ( - DataType, - DataTypeSize, - GemmUniversalMode, - LayoutType, - OpcodeClass, - ShortDataTypeNames, - SwizzlingFunctor -) - -from cutlass_cppgen.backend import compiler -from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal -from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation -from cutlass_cppgen.shape import GemmCoord, MatrixCoord -from cutlass_cppgen.utils.datatypes import torch_type - - -class GemmUniversalLauncher: - def __init__( - self, - operation, - seed=2080, - verification=True, - iterations=500, - compiler_mode= "nvcc", - **kwargs, - ) -> None: - self.math_operation = operation.tile_description.math_instruction.math_operation - self.verification = verification - - if compiler_mode == "nvcc": - compiler.nvcc() - elif compiler_mode == "nvrtc": - compiler.nvrtc() - else: - raise Exception(f"Unexpected compiler string {compiler_mode}") - - op_list = [operation] - if operation.arch < 90: - # Split K via Python is currently only supported for pre-SM90 kernels - self.reduction_operation: ReductionOperation = ReductionOperation( - shape=MatrixCoord(4, 32 * operation.C.alignment), - C=operation.C, - element_accumulator=operation.tile_description.math_instruction.element_accumulator, - element_compute=operation.epilogue_functor.element_epilogue, - epilogue_functor=operation.epilogue_functor, - count=operation.C.alignment, - ) - op_list.append(self.reduction_operation) - - compiler.add_module(op_list, bypass_cache=False) - - self.operation = operation - - self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element) - self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element) - self.dtype_C = torch_type(operation.C.element) - self.dtype_D = torch_type(operation.epilogue_functor.element_output) - - element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) - - if element_size == 1: - self.rand_max = 1 - self.rand_min = 0 - elif element_size <= 8: - self.rand_max = 1 - self.rand_min = -1 - elif element_size == 16: - self.rand_max = 4 - self.rand_min = -4 - else: - self.rand_max = 8 - self.rand_min = -8 - - self.seed = seed - - self.compute_type = operation.epilogue_functor.element_epilogue - self.accumulator_type = operation.tile_description.math_instruction.element_accumulator - - def print_problem_size(self, p, mode, batch_count): - if mode == GemmUniversalMode.Gemm: - mode = "Gemm" - elif mode == GemmUniversalMode.Batched: - mode = "GemmBatched" - elif mode == GemmUniversalMode.GemmSplitKParallel: - mode = "GemmSplitKParallel" - print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}") - - def uniform_init(self, shape, dtype, layout): - size = prod(shape) - if dtype.is_floating_point: - # Initialize data in FP32 and call convert to the data type we desire. - # This is a workaround for the following error that occurs when attempting to - # call uniform_ on a tensor with torch.float8_e4m3fn data: - # RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn' - data = torch.ceil( - torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_( - self.rand_min - 0.5, self.rand_max - 0.5) - ).to(dtype) - else: - # PyTorch does not currently support integer-typed matrix multiplications on GPU. - # Fall back to CPU for integer type references. - data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1) - - is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1) - - if dtype == torch.float64 or dtype == torch.float32 or is_fp8: - data = data.to("cpu") - - data_ref = data.reshape(shape) - - if layout == LayoutType.RowMajor: - data_cutlass = data_ref - else: - data_cutlass = data_ref.transpose(-1, -2).contiguous() - - data_cutlass = data_cutlass.to("cuda") - - # As of this writing, few operations in PyTorch are supported with FP8 data. - # Thus, we perform computation in FP32 for FP8 reference checks. - if is_fp8: - data_ref = data_ref.to(torch.float32) - - return data_cutlass, data_ref - - def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): - # If any tensor is on CPU, place all tensors on CPU unless only - # tensor C is on CPU - # Handle mixed-input cases by casting to the larger data type and overriding - # to whatever the data type of the larger type is - if self.dtype_A != self.dtype_B: - if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]: - tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device) - else: - tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device) - - devices = [x.device.type for x in [tensor_A, tensor_B]] - if tensor_C is not None: - devices.append(tensor_C.device.type) - - if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]: - device = torch.device("cpu") - else: - device = tensor_A.device - - tensor_A = tensor_A.to(device) - tensor_B = tensor_B.to(device) - if tensor_C is not None: - tensor_C = tensor_C.to(device) - - dtype = torch_type(self.compute_type) - alpha_torch = torch.tensor([alpha], device=device).to(dtype) - beta_torch = torch.tensor([beta], device=device).to(dtype) - - tmp = tensor_A @ tensor_B - tensor_D_ref = (alpha_torch * tmp) - if tensor_C is not None: - tensor_D_ref += (tensor_C * beta_torch) - return tensor_D_ref.to(self.dtype_D) - - def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0): - torch.random.manual_seed(self.seed) - - # Assign an actual batch count in cases where we are not running in batched mode. - # This is to differentiate between the number of split K slices and the batch count, - # which are overloaded within the single `batch_count` variable. - if mode == GemmUniversalMode.Batched: - true_batch_count = batch_count - else: - true_batch_count = 1 - - def transpose(layout): - if layout == LayoutType.RowMajor: - return LayoutType.ColumnMajor - else: - return LayoutType.RowMajor - - tensor_A, tensor_A_ref = self.uniform_init( - (true_batch_count, problem_size.m, problem_size.k), - self.dtype_A, - self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout), - ) - tensor_B, tensor_B_ref = self.uniform_init( - (true_batch_count, problem_size.k, problem_size.n), - self.dtype_B, - self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout), - ) - if self.dtype_C is not None: - tensor_C, tensor_C_ref = self.uniform_init( - (true_batch_count, problem_size.m, problem_size.n), - self.dtype_C, - self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), - ) - else: - tensor_C = None - tensor_C_ref = None - - tensor_D, _ = self.uniform_init( - (true_batch_count, problem_size.m, problem_size.n), - self.dtype_D, - self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), - ) - tensor_D = torch.zeros_like(tensor_D) - - if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: - alpha = int(alpha) - beta = int(beta) - - # - # Launch kernel - # - - arguments = GemmArguments( - operation=self.operation, - problem_size=problem_size, - A=tensor_A, - B=tensor_B, - C=tensor_C, - D=tensor_D, - output_op=self.operation.epilogue_type(alpha, beta), - gemm_mode=mode, - split_k_slices=split_k_slices, - batch=batch_count, - ) - - if mode == GemmUniversalMode.GemmSplitKParallel: - reduction_arguments = ReductionArguments( - self.reduction_operation, - problem_size=[problem_size.m, problem_size.n], - partitions=split_k_slices, - workspace=arguments.ptr_D, - destination=tensor_D, - source=tensor_C, - output_op=self.reduction_operation.epilogue_type(alpha, beta), - ) - - self.operation.run(arguments) - - if mode == GemmUniversalMode.GemmSplitKParallel: - self.reduction_operation.run(reduction_arguments) - - passed = True - - if self.verification: - if mode == GemmUniversalMode.GemmSplitKParallel: - reduction_arguments.sync() - - # Free memory allocated by args because we are not - # calling `arguments.sync()` in this case (which will free memory) - arguments.free() - else: - arguments.sync() - tensor_D_ref = self.reference( - problem_size, - tensor_A_ref, - tensor_B_ref, - tensor_C_ref, - alpha, - beta, - ) - - tensor_D_ref = tensor_D_ref.to('cuda') - - if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor: - tensor_D = tensor_D.transpose(-1, -2).contiguous() - - passed = tensor_D.equal(tensor_D_ref) - - try: - assert passed - except AssertionError: - self.print_problem_size(problem_size, mode, batch_count) - del arguments - if mode == GemmUniversalMode.GemmSplitKParallel: - del reduction_arguments - - return passed - - -def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"): - passed = True - - minimum_operand_element_size = min( - DataTypeSize[operation.A.element], DataTypeSize[operation.B.element] - ) - opcode_class = operation.tile_description.math_instruction.opcode_class - - if opcode_class == OpcodeClass.Simt: - alignment = 1 - else: - alignment = 128 // minimum_operand_element_size - - alignment_m = alignment - alignment_n = alignment - alignment_k = alignment - - # INT8 alignment constraints - if opcode_class == OpcodeClass.Simt: - A_is_s8 = operation.A.element == DataType.s8 - B_is_s8 = operation.B.element == DataType.s8 - - if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor: - alignment_m = 4 - if B_is_s8 == DataType.s8 and operation.A.layout == LayoutType.RowMajor: - alignment_n = 4 - if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor): - alignment_k = 4 - - threadblock_k = operation.tile_description.threadblock_shape[2] - - assert testcase != "interleaved" - - supports_split_k = operation.arch < 90 and not operation.swizzling_functor == SwizzlingFunctor.StreamK - - if testcase == "multistage": - modes = [GemmUniversalMode.Gemm] - problem_size_m = [16, 528] - problem_size_n = [16, 528] - problem_size_k = [ - threadblock_k, - threadblock_k * operation.tile_description.stages - + operation.tile_description.math_instruction.instruction_shape[2], - ] - problem_alpha = [1.0] - problem_beta = [0.0] - batch_counts = [1] - else: - modes = [GemmUniversalMode.Gemm] - batch_counts = [1, 2, 3, 5, 7] - if supports_split_k: - modes.append(GemmUniversalMode.GemmSplitKParallel) - - problem_size_m = [alignment_m, 512 - 3 * alignment_m] - problem_size_n = [alignment_n, 512 - 2 * alignment_n] - if operation.tile_description.stages is None: - stages_for_k_calc = 7 - else: - stages_for_k_calc = operation.tile_description.stages - problem_size_k = [ - alignment_k, - threadblock_k * stages_for_k_calc - alignment_k, - threadblock_k * stages_for_k_calc * 3 - alignment_k, - ] - problem_alpha = [1.0] - problem_beta = [2.0] - - testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode) - - for mode in modes: - for m in problem_size_m: - for n in problem_size_n: - for k in problem_size_k: - for batch_count in batch_counts: - for alpha in problem_alpha: - for beta in problem_beta: - # skip very small K problems - if testcase == "universal": - if k // batch_count < 2 * threadblock_k: - continue - - problem_size = GemmCoord(m, n, k) - - if supports_split_k: - split_k_slices = batch_count - else: - split_k_slices = 1 - - overridden_mode = mode - if mode == GemmUniversalMode.Gemm and batch_count > 1: - overridden_mode = GemmUniversalMode.Batched - - passed = testbed.run( - overridden_mode, - problem_size, - batch_count, - split_k_slices, - alpha, - beta, - ) - - if not passed: - return False - - return passed diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py deleted file mode 100644 index bc5e7467b1e0040ce3012ff8541dfbac381bb861..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py +++ /dev/null @@ -1,44 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import pathlib -import unittest - - -if __name__ == '__main__': - loader = unittest.TestLoader() - script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' - tests = loader.discover(script_dir, 'gemm_*.py') - testRunner = unittest.runner.TextTestRunner() - results = testRunner.run(tests) - if not results.wasSuccessful(): - raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py deleted file mode 100644 index 28bba3e922961c96df75f8685e3064ab55cbbc87..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py +++ /dev/null @@ -1,260 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from cutlass_library import SubstituteTemplate - -import cutlass_cppgen -from cutlass_library import ( - DataTypeNames, - EpilogueScheduleSuffixes, - KernelScheduleSuffixes, - LayoutType, - OpcodeClassNames, - ShortDataTypeNames, - ShortLayoutTypeNames -) -from cutlass_cppgen.backend import library - -from gemm_testbed import test_all_gemm - - -class Layout: - """ - Utility class to map transpose and non-transpose terminology to row- and column-major terminology - """ - - T = LayoutType.RowMajor - N = LayoutType.ColumnMajor - - -class LayoutCombination: - """ - Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs - """ - - NNN = (Layout.N, Layout.N, Layout.N) - NNT = (Layout.N, Layout.N, Layout.T) - NTN = (Layout.N, Layout.T, Layout.N) - NTT = (Layout.N, Layout.T, Layout.T) - TNN = (Layout.T, Layout.N, Layout.N) - TNT = (Layout.T, Layout.N, Layout.T) - TTN = (Layout.T, Layout.T, Layout.N) - TTT = (Layout.T, Layout.T, Layout.T) - - -def get_name( - layouts, - alignments, - element_output, - element_accumulator, - element_epilogue, - cluster_shape, - threadblock_shape, - stages, - element_a, - element_b, - element_c, - arch, - opclass, - kernel_schedule=None, - epilogue_schedule=None, - suffix="", -): - """ - Generates a procedural name for a test case. - - :param layouts: indexable container of layouts of A, B, and C operands - :param alignments: indexable container of alignments of A, B, and C operands - :param element_output: data type of the output element - :param element_accumulator: data type used in accumulation - :param element_epilogue: data type used in computing the epilogue - :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched - :param threadblock_shape: indexable container of dimensions of threadblock tiles - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param element_a: data type of operand A - :param element_b: data type of operand B - :param element_c: data type of operand C - :param arch: compute capability of kernel being generated - :type arch: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass_cppgen.OpcodeClass - :param kernel_schedule: kernel_schedule type - :type kernel_schedule: cutlass_cppgen.KernelScheduleType - :param epilogue_schedule: epilogue_schedule type - :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType - :param suffix: additional string to add to the suffix of the name - :type suffix: str - - :return: str - """ - name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" - return SubstituteTemplate( - name_format, - { - "arch": str(arch), - "eA": DataTypeNames[element_a], - "eB": DataTypeNames[element_b], - "eC": DataTypeNames[element_c], - "lA": ShortLayoutTypeNames[layouts[0]], - "lB": ShortLayoutTypeNames[layouts[1]], - "lC": ShortLayoutTypeNames[layouts[2]], - "opclass": OpcodeClassNames[opclass], - "acc": DataTypeNames[element_accumulator], - "cM": str(cluster_shape[0]), - "cN": str(cluster_shape[1]), - "cK": str(cluster_shape[2]), - "tbM": str(threadblock_shape[0]), - "tbN": str(threadblock_shape[1]), - "tbK": str(threadblock_shape[2]), - "stages": str(stages) if stages is not None else "auto", - "aA": str(alignments[0]), - "aB": str(alignments[1]), - "aC": str(alignments[2]), - "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule], - "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule], - "suffix": "" if suffix is None else suffix, - }, - ) - - -def add_test_gemm( - cls=None, - cc=None, - element=None, - layouts=None, - alignments=None, - element_output=None, - element_accumulator=None, - cluster_shape=None, - threadblock_shape=None, - warp_count=None, - stages=None, - opclass=None, - swizzle=None, - kernel_schedule=None, - epilogue_schedule=None, - compilation_modes=['nvcc', 'nvrtc'], - element_A=None, - element_B=None, - element_C=None): - """ - Create test-running functions with the given specification and set it as a method of ``cls``. - - :param cls: class to which the generated method will be added - :type cls: type - :param cc: compute capability to compile for - :type cc: int - :param element: data type of A and B operands - :type element: cutlass_cppgen.DataType.f16 - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass_cppgen.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass_cppgen.DataType - :param cluster_shape: dimensions of clusters - :type cluster_shape: list or tuple - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass_cppgen.OpcodeClass - :param swizzle: threadblock swizzling functor - :param kernel_schedule: kernel schedule to use - :type kernel_schedule: cutlass_cppgen.KernelScheduleType - :param epilogue_schedule: epilogue schedule to use - :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType - :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc') - :type compilation_modes: list, - :param element_A: data type of operand A. If set, overrides ``element`` - :type element_A: cutlass_cppgen.DataType - :param element_B: data type of operand B. If set, overrides ``element`` - :type element_B: cutlass_cppgen.DataType - :param element_C: data type of operand C. If set, overrides ``element`` - :type element_C: cutlass_cppgen.DataType - """ - - if element_A is None: - element_A = element - if element_B is None: - element_B = element - if element_C is None: - element_C = element - if element_output is None: - element_output = element - if element_accumulator is None: - element_accumulator = element - - for compilation_mode in compilation_modes: - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_C, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator, - kernel_cc=cc) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - - td = plan.tile_descriptions()[0] - - if warp_count is not None: - td.warp_count = warp_count - td.threadblock_shape = threadblock_shape - td.stages = stages - td.cluster_shape = cluster_shape - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode)) - - element_epilogue = element_accumulator - name = get_name( - layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator, - element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape, - stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass, - kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}') - - setattr(cls, name, run) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/installation.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/installation.py deleted file mode 100644 index f550c394812c7fede55070e4c99c4471a69c2f88..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/installation.py +++ /dev/null @@ -1,57 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Tests for a successful installation of the CUTLASS Python interface -""" - -import os -import unittest - -import cutlass_cppgen -import cutlass_library - - -class InstallationTest(unittest.TestCase): - def test_cutlass_source_paths(self): - """ - Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages - """ - src_file = 'include/cutlass/cutlass.h' - library_file = os.path.join(cutlass_library.source_path, src_file) - cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file) - assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." - assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." - - -if __name__ == "__main__": - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py deleted file mode 100644 index 2b5d46d45d617198a46bec85cd7218cb5431a7b1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py +++ /dev/null @@ -1,284 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Tests the high-level Conv2d interface -""" - -from math import ceil -import unittest - -import cutlass_cppgen -import cutlass_cppgen.utils.datatypes as datatypes -from cutlass_cppgen.backend.utils.device import device_cc -from utils import ExpectException -import os - - -class Conv2dEquivalence: - """ - Helper class for testing the equivalence of different constructions of the Conv2d interface - """ - def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, - alignment_A, alignment_B, alignment_C): - - self.element_A = element_A - self.element_B = element_B - self.element_C = element_C - self.element_D = element_D - self.element_accumulator = element_accumulator - self.alignment_A = alignment_A - self.alignment_B = alignment_B - self.alignment_C = alignment_C - - self.conv_kind = conv_kind - - self.plan = cutlass_cppgen.op.Conv2d( - kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, - element_D=element_D, element_accumulator=element_accumulator) - - self.op = self.plan.construct( - alignment_A=self.alignment_A, alignment_B=self.alignment_B, - alignment_C=self.alignment_C) - - def _plans_equal(self, other_plan) -> bool: - """ - Compares whether two plans are equal - - :param other_plan: plan to compare against the default Conv2d - :type other_plan: cutlass_cppgen.op.Conv2d - - :return: whether `other_plan` is equivalent to `self.plan` - :rtype: bool - """ - other_op = other_plan.construct( - alignment_A=self.alignment_A, alignment_B=self.alignment_B, - alignment_C=self.alignment_C) - - return self.op.rt_module.emit() == other_op.rt_module.emit() - - def generic_test(self): - """ - Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types - and layouts for constructing the Conv2d interface - """ - if not datatypes.is_numpy_available(): - return - - # Test when specifying all parameters - plan_other = cutlass_cppgen.op.Conv2d( - kind=self.conv_kind, - element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, - element_D=self.element_D, element_accumulator=self.element_accumulator) - assert self._plans_equal(plan_other) - - # Test when specifying all parameters but A - plan_other = cutlass_cppgen.op.Conv2d( - kind=self.conv_kind, - element_B=self.element_B, element_C=self.element_C, - element_D=self.element_D, element_accumulator=self.element_accumulator, - element=self.element_A) - assert self._plans_equal(plan_other) - - # Test when specifying all parameters but A and B as tensors using generic element and output - plan_other = cutlass_cppgen.op.Conv2d( - kind=self.conv_kind, - element_C=self.element_C, - element_D=self.element_D, element_accumulator=self.element_accumulator, - element=self.element_A) - assert self._plans_equal(plan_other) - - # Test without explicit accumulator. Only run if the type of C and the accumulator are equal - if self.element_C == self.element_accumulator: - plan_other = cutlass_cppgen.op.Conv2d( - kind=self.conv_kind, - element_C=self.element_C, - element_D=self.element_D, - element=self.element_A) - assert self._plans_equal(plan_other) - - # Test with only the generic types. Only rune if the types of A, B, C, and D are the same - if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D - and self.element_A == self.element_accumulator): - plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A) - assert self._plans_equal(plan_other) - - def numpy_test(self): - """ - Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend - """ - if not datatypes.is_numpy_available(): - return - - import numpy as np - type_A = datatypes.numpy_type(self.element_A) - type_B = datatypes.numpy_type(self.element_B) - type_C = datatypes.numpy_type(self.element_C) - type_D = datatypes.numpy_type(self.element_D) - type_accum = datatypes.numpy_type(self.element_accumulator) - - size = (2, 2) - A = np.zeros(size, dtype=type_A) - B = np.zeros(size, dtype=type_B) - C = np.zeros(size, dtype=type_C) - D = np.zeros(size, dtype=type_D) - - return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) - - def torch_test(self): - """ - Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend - """ - if not datatypes.is_torch_available(): - return - - import torch - type_A = datatypes.torch_type(self.element_A) - type_B = datatypes.torch_type(self.element_B) - type_C = datatypes.torch_type(self.element_C) - type_D = datatypes.torch_type(self.element_D) - type_accum = datatypes.torch_type(self.element_accumulator) - - size = (2, 2) - - A = torch.empty(size, dtype=type_A) - B = torch.empty(size, dtype=type_B) - C = torch.empty(size, dtype=type_C) - D = torch.empty(size, dtype=type_D) - - return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) - - def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): - # Test when specifying all parameters via tensors - plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) - assert self._plans_equal(plan_np) - - # Test when specifying all parameters but A as tensors - plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) - assert self._plans_equal(plan_np) - - # Test when specifying all parameters but A and B as tensors and using generic element and output - if type_A == type_B: - plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) - assert self._plans_equal(plan_np) - - # Test without explicit accumulator. Only run if the type of C and the accumulator. - if type_C == type_accum: - plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) - assert self._plans_equal(plan_np) - - # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. - if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): - plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A) - assert self._plans_equal(plan_np) - - def test_all(self): - """ - Runs all tests on the Gemm interface - """ - self.generic_test() - self.numpy_test() - self.torch_test() - - -@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') -class ConvEquivalenceTest(unittest.TestCase): - """ - Tests the equivalence of different constructions of the Conv2d interface - """ - pass - -type2alignment = { - cutlass_cppgen.DataType.f16: 8, - cutlass_cppgen.DataType.f32: 4 -} - -def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): - - test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}" - - def run(self): - conv2d_eq = Conv2dEquivalence( - conv_kind=conv_kind, - element_A=element_A, element_B=element_B, - element_C=element_C, element_D=element_D, - element_accumulator=element_accumulator, - alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B], - alignment_C=type2alignment[element_C] - ) - conv2d_eq.test_all() - - setattr(ConvEquivalenceTest, test_name, run) - -for conv_kind in ["fprop", "wgrad", "dgrad"]: - for types in [ - [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16], - [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32], - [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16], - [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32], - [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32] - ]: - add_test(conv_kind, types[0], types[1], types[2], types[3], types[4]) - - -@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') -class Conv2dErrorTests(unittest.TestCase): - """ - Tests various error scenarios that arise with the high-level Gemm interface - """ - - def test_alignment(self): - """ - Tests case in which the alignment specified is unsupported - """ - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) - - with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'): - op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) - - def test_invalid_tile_description(self): - """ - Tests scenarios in which an invalid tile description is provided for a given CC - """ - plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) - - td = plan.tile_descriptions()[0] - td.threadblock_shape=[17, 32, 5] - - plan.tile_description = td - with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'): - plan.compile() - # Clean up the error message - os.remove("./cutlass_python_compilation_device_error.txt") - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py deleted file mode 100644 index e7d67f4d07f01b0936ff5796bfb6fe4c98b5c031..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py +++ /dev/null @@ -1,254 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Test the EVT interface -""" - -import numpy as np -import unittest - -import cutlass_cppgen -from cutlass_cppgen import LayoutType, Tensor -from cutlass_cppgen.backend.utils.device import device_cc -from cutlass_cppgen.epilogue import reshape, permute - -from utils import ExpectException - - -@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") -class EVTErrorTests(unittest.TestCase): - """ - Tests various error scenarios that arise with the EVT interface - """ - @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'") - def test_root_not_d(self): - """ - Test when "D" does not exist in Sm90 EVT - """ - def evt_root_not_d(accum, alpha): - F = accum * alpha - return F - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "alpha": 1.2, - "F": self.fake_tensor(np.float16, (6, 512, 512)) - } - - with ExpectException(device_cc() == 90, - "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, " - "but the variable 'D' is not found in the return values.", True): - - cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors) - - def test_no_accum(self): - """ - Test when "accum" is not in input arguments - """ - def evt_no_accum(alpha, C): - D = alpha * C - return D - - example_tensors = { - "C": self.fake_tensor(np.float16, (6, 512, 512)), - "alpha": 1.2, - "D": self.fake_tensor(np.float16, (6, 512, 512)) - } - - with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True): - cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors) - - @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size") - def test_too_much_shared_memory(self): - """ - Test when the epilogue consumes too much shared memory - """ - def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8): - D1 = accum + C1 - D2 = D1 + C2 - D3 = D2 + C3 - D4 = D3 + C4 - D5 = D4 + C5 - D6 = D5 + C6 - D7 = D6 + C7 - D = D7 + C8 - return D, D1, D2, D3, D4, D5, D6, D7 - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "C1": self.fake_tensor(np.float16, (6, 512, 512)), - "C2": self.fake_tensor(np.float16, (6, 512, 512)), - "C3": self.fake_tensor(np.float16, (6, 512, 512)), - "C4": self.fake_tensor(np.float16, (6, 512, 512)), - "C5": self.fake_tensor(np.float16, (6, 512, 512)), - "C6": self.fake_tensor(np.float16, (6, 512, 512)), - "C7": self.fake_tensor(np.float16, (6, 512, 512)), - "C8": self.fake_tensor(np.float16, (6, 512, 512)), - "D1": self.fake_tensor(np.float16, (6, 512, 512)), - "D2": self.fake_tensor(np.float16, (6, 512, 512)), - "D3": self.fake_tensor(np.float16, (6, 512, 512)), - "D4": self.fake_tensor(np.float16, (6, 512, 512)), - "D5": self.fake_tensor(np.float16, (6, 512, 512)), - "D6": self.fake_tensor(np.float16, (6, 512, 512)), - "D7": self.fake_tensor(np.float16, (6, 512, 512)), - "D": self.fake_tensor(np.float16, (6, 512, 512)) - } - - epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors) - - plan = cutlass_cppgen.op.Gemm( - element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor, - element_accumulator=np.float32 - ) - - with ExpectException(True, - "RuntimeError: The epilogue consumes too much shared memory. " - "No valid tile description is found in the generator.", True): - plan.epilogue_visitor = epilogue_visitor - - def test_not_ssa(self): - """ - Test when the epilogue is not in SSA - """ - def evt_redefine(accum, C, alpha): - F = accum + C - F = F * alpha - D = F - return D, F - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "C": self.fake_tensor(np.float16, (6, 512, 512)), - "alpha": 1.5, - "D": self.fake_tensor(np.float16, (6, 512, 512)), - "F": self.fake_tensor(np.float16, (6, 512, 512)) - } - - with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True): - cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors) - - def evt_undefine(accum, alpha): - F = accum + C - D = F * alpha - return D, F - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "alpha": 1.5, - "D": self.fake_tensor(np.float16, (6, 512, 512)), - "F": self.fake_tensor(np.float16, (6, 512, 512)) - } - - with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True): - cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors) - - def test_missing_example_tensor(self): - """ - Test when the example tensor of an input/output variable is not provided - """ - def evt_missing_example_tensor(accum, C): - D = accum + C - return D - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "C": self.fake_tensor(np.float16, (6, 512, 512)), - } - - with ExpectException(True, "RuntimeError: Example input for D is not provided.", True): - cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "D": self.fake_tensor(np.float16, (6, 512, 512)), - } - - with ExpectException(True, "RuntimeError: Example input for C is not provided.", True): - cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) - - def test_return_expression(self): - """ - Test when the return value is an expression - """ - def evt_return_expr(accum, C): - return accum + C - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 512)), - "C": self.fake_tensor(np.float16, (6, 512, 512)), - } - - with ExpectException(True, "SyntaxError: Return value cannot be an expression", True): - cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors) - - def test_incompatible_shape(self): - """ - Test when the shape of example tensors are incompatible - """ - def evt_incompatible_shape(accum, C): - D = accum + C - return D - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 256, 512)), - "C": self.fake_tensor(np.float16, (6, 512, 512)), - "D": self.fake_tensor(np.float16, (6, 512, 512)) - } - - with ExpectException(True, - "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True): - cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors) - - def test_no_matching_impl(self): - def evt_no_matching_impl(accum, bias): - D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1)) - return D - - example_tensors = { - "accum": self.fake_tensor(np.float16, (6, 512, 256)), - "bias": self.fake_tensor(np.float16, (16, 32)), - "D": self.fake_tensor(np.float16, (6, 512, 256)) - } - - with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True): - cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors) - # - # Helper functions - # - - def fake_tensor(self, element, shape): - return Tensor(element=element, shape=shape, layout_tag=LayoutType.RowMajor) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py deleted file mode 100644 index 2913d5933f5342cc58b4f252657a724d2c7692da..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py +++ /dev/null @@ -1,354 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Tests the high-level GEMM interface -""" - -from math import ceil -import unittest - -import cutlass_cppgen -import cutlass_cppgen.utils.datatypes as datatypes -from cutlass_cppgen.backend.utils.device import device_cc -from utils import ExpectException - - -class GemmEquivalence: - """ - Helper class for testing the equivalence of different constructions of the Gemm interface - """ - def __init__(self, element_A, element_B, element_C, element_D, element_accumulator, - layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C): - self.element_A = element_A - self.element_B = element_B - self.element_C = element_C - self.element_D = element_D - self.element_accumulator = element_accumulator - self.layout_A = layout_A - self.layout_B = layout_B - self.layout_C = layout_C - self.alignment_A = alignment_A - self.alignment_B = alignment_B - self.alignment_C = alignment_C - self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, - element_D=element_D, element_accumulator=element_accumulator, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) - self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - - def _plans_equal(self, other_plan) -> bool: - """ - Compares whether two plans are equal - - :param other_plan: plan to compare against the default GEMM - :type other_plan: cutlass_cppgen.op.Gemm - - :return: whether `other_plan` is equivalent to `self.plan` - :rtype: bool - """ - other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) - - # Compare whether the operations are equal by comparing the C++ code that would be emitted for them - return self.op.rt_module.emit() == other_op.rt_module.emit() - - def generic_test(self): - """ - Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types - and layouts for constructing the Gemm interface - """ - if not datatypes.is_numpy_available(): - return - - # Test when specifying all parameters - plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, - element_D=self.element_D, element_accumulator=self.element_accumulator, - layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) - assert self._plans_equal(plan_other) - - # Test when specifying all parameters but A - plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C, - element_D=self.element_D, element_accumulator=self.element_accumulator, - layout_B=self.layout_B, layout_C=self.layout_C, - element=self.element_A, layout=self.layout_A) - assert self._plans_equal(plan_other) - - # Test when specifying all parameters but A and B as tensors and using generic element and output - # Only run this test if the layouts and types for A and B are equal. - if self.element_A == self.element_B and self.layout_A == self.layout_B: - plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, - layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) - assert self._plans_equal(plan_other) - - # Test without explicit accumulator. Only run if the type of C and the accumulator. - if self.element_C == self.element_accumulator: - plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, - element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, - layout_C=self.layout_C) - assert self._plans_equal(plan_other) - - # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. - if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D - and self.element_A == self.element_accumulator and - self.layout_A == self.layout_B and self.layout_A == self.layout_C): - plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A) - assert self._plans_equal(plan_other) - - def numpy_test(self): - """ - Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend - """ - if not datatypes.is_numpy_available(): - return - - import numpy as np - type_A = datatypes.numpy_type(self.element_A) - type_B = datatypes.numpy_type(self.element_B) - type_C = datatypes.numpy_type(self.element_C) - type_D = datatypes.numpy_type(self.element_D) - type_accum = datatypes.numpy_type(self.element_accumulator) - - layout_to_order = { - cutlass_cppgen.LayoutType.RowMajor: 'C', - cutlass_cppgen.LayoutType.ColumnMajor: 'F' - } - size = (2, 2) - A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) - B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B) - C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C) - D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) - - # Test when specifying all parameters via tensors - plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) - assert self._plans_equal(plan_np) - - # Test when specifying all parameters but A as tensors - plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) - assert self._plans_equal(plan_np) - - # Test when specifying all parameters but A and B as tensors and using generic element and output - # Only run this test if the layouts and types for A and B are equal. - if type_A == type_B and self.layout_A == self.layout_B: - plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) - assert self._plans_equal(plan_np) - - # Test without explicit accumulator. Only run if the type of C and the accumulator. - if type_C == type_accum: - plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D) - assert self._plans_equal(plan_np) - - # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. - if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and - self.layout_A == self.layout_B and self.layout_A == self.layout_C): - plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A) - assert self._plans_equal(plan_np) - - def test_all(self): - """ - Runs all tests on the Gemm interface - """ - self.generic_test() - self.numpy_test() - - -class GemmEquivalenceTest(unittest.TestCase): - """ - Tests the equivalence of different constructions of the Gemm interface - """ - @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") - def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): - gemm_eq = GemmEquivalence( - element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, - layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, - alignment_A=8, alignment_B=8, alignment_C=8) - gemm_eq.test_all() - - @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") - def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): - gemm_eq = GemmEquivalence( - element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, - layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor, - alignment_A=8, alignment_B=8, alignment_C=8) - gemm_eq.test_all() - - @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") - def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): - gemm_eq = GemmEquivalence( - element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, - element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, - layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, - alignment_A=8, alignment_B=8, alignment_C=8) - gemm_eq.test_all() - - @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") - def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): - gemm_eq = GemmEquivalence( - element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64, - element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64, - layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, - alignment_A=1, alignment_B=1, alignment_C=1) - gemm_eq.test_all() - - -class GemmErrorTests(unittest.TestCase): - """ - Tests various error scenarios that arise with the high-level Gemm interface - """ - - def test_alignment(self): - """ - Tests case in which the alignment specified is unsupported - """ - plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) - - with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'): - op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) - - def test_tensorop_availability(self): - """ - Tests case in which only SIMT operations are available but TensorOp is requested - """ - cc = device_cc() - - # F64 Tensor Core operations are only avaiable on certain devices - supports_tensorop_f64 = cc in [80, 89, 90] - plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) - - error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' - with ExpectException(not supports_tensorop_f64, error_msg): - plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp - - expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt - assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' - - @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") - def test_opclass_switch(self): - """ - Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) - """ - plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) - assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp - - # Ensure that all tile descriptions have opclass of TensorOp - for td in plan.tile_descriptions(): - assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp - - plan.opclass = cutlass_cppgen.OpcodeClass.Simt - - # Ensure that all tile descriptions have opclass of Simt - for td in plan.tile_descriptions(): - assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt - - def test_invalid_tile_description(self): - """ - Tests scenarios in which an invalid tile description is provided for a given CC - """ - cc = device_cc() - plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) - td = plan.tile_descriptions()[0] - stages = td.stages - - # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage - # count should be used - with ExpectException(cc < 90, f'Requested zero stages'): - td.stages = 0 - plan.construct(td) - - if cc < 90: - with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): - td.stages = 3 - plan.construct(td) - elif cc == 90: - original_kschedule = td.kernel_schedule - original_eschedule = td.epilogue_schedule - with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): - td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong - td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized - td.stages = 3 - plan.construct(td) - # Reset schedules - td.kernel_schedule = original_kschedule - td.epilogue_schedule = original_eschedule - elif cc in [100, 101, 103]: - with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): - td.stages = 3 - plan.construct(td) - - with ExpectException(True, f'Requested too many stages'): - td.stages = 100 - plan.construct(td) - - # Reset stage count - td.stages = stages - - cluster_shape = td.cluster_shape - with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'): - td.cluster_shape = [2, 1, 1] - plan.construct(td) - - # Reset cluster shape - td.cluster_shape = cluster_shape - - with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): - td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong - td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized - plan.construct(td) - - with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): - td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong - td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto - plan.construct(td) - - with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): - td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto - td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized - plan.construct(td) - - with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): - td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative - td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative - td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK - plan.construct(td) - - # Ensure that all returned tile descriptions are unique - ops = {} - for i, td in enumerate(plan.tile_descriptions()): - op = plan.construct(td) - code_str = op.rt_module.emit() - if code_str in ops: - conflicting_td = ops[code_str] - assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}' - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py deleted file mode 100644 index 9f93ca26e2d79a15dab4dd0045836ebd9fe62757..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py +++ /dev/null @@ -1,69 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Helper functions & classes for interface test -""" -class ExpectException: - """ - Utility class to assert that an exception was raised when expected - - Example: - - .. highlight:: python - .. code-block:: python - - with ExceptionExpected(True, 'Division by zero'): - x = 1.0 / 0.0 - - :param exception_expected: whether an exception is expected to be raised - :type exception_expected: bool - :param message: message to print if an exception is raised when not expected or vice versa - :type message: str - """ - def __init__(self, exception_expected: bool, message: str = '', verify_msg=False): - self.exception_expected = exception_expected - self.message = message - self.verify_msg = verify_msg - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, traceback): - exception_raised = exc_type is not None - assert self.exception_expected == exception_raised, self.message - if self.verify_msg: - exc_message = f"{exc_type.__name__}: {exc_val}" - assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}" - - # Suppress the exception - return True diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py deleted file mode 100644 index b7cdc421ccffffeb7bd1696aaf9916330a6625ca..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py +++ /dev/null @@ -1,75 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utility script for discovering and running all PyCuTe tests -""" - -import argparse -import logging -import pathlib -import unittest - - -def numeric_log_level(log_level: str) -> int: - """ - Converts the string identifier of the log level into the numeric identifier used - in setting the log level - - :param x: string representation of log level (e.g., 'INFO', 'DEBUG') - :type x: str - - :return: numeric representation of log level - :rtype: int - """ - numeric_level = getattr(logging, log_level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError(f"Invalid log level: {log_level}") - return numeric_level - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, - help='Logging level to be used by the generator script') - args = parser.parse_args() - - # Set the logging level based on the user-provided `--log-level` command-line option - logging.basicConfig(level=args.log_level) - - loader = unittest.TestLoader() - script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' - tests = loader.discover(script_dir, "test_*.py") - test_runner = unittest.runner.TextTestRunner() - results = test_runner.run(tests) - if not results.wasSuccessful(): - raise Exception("Test cases failed") diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py deleted file mode 100644 index d4330377cab7079ea16422f194ddf4f2403ea507..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py +++ /dev/null @@ -1,95 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.coalesce -""" - -import logging -import unittest - -from pycute import * - -_LOGGER = logging.getLogger(__name__) - - -class TestCoalesce(unittest.TestCase): - def helper_test_coalesce(self, layout): - layoutR = coalesce(layout) - - _LOGGER.debug(f"{layout} => {layoutR}") - - self.assertEqual(size(layoutR), size(layout)) - - for i in range(size(layout)): - self.assertEqual(layoutR(i), layout(i)) - - def test_coalesce(self): - layout = Layout(1,0) - self.helper_test_coalesce(layout) - - layout = Layout(1,1) - self.helper_test_coalesce(layout) - - layout = Layout((2,4)) - self.helper_test_coalesce(layout) - - layout = Layout((2,4,6)) - self.helper_test_coalesce(layout) - - layout = Layout((2,4,6), (1,6,2)) - self.helper_test_coalesce(layout) - - layout = Layout((2,1,6), (1,7,2)) - self.helper_test_coalesce(layout) - - layout = Layout((2,1,6), (4,7,8)) - self.helper_test_coalesce(layout) - - layout = Layout((2,(4,6))) - self.helper_test_coalesce(layout) - - layout = Layout((2,4), (4,1)) - self.helper_test_coalesce(layout) - - layout = Layout((2,4,6), (24,6,1)) - self.helper_test_coalesce(layout) - - layout = Layout((2,1,3), (2,4,4)) - self.helper_test_coalesce(layout) - - layout = Layout(((2,2),(2,2)), ((1,4),(8,32))) - self.helper_test_coalesce(layout) - - -if __name__ == "__main__": - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py deleted file mode 100644 index 5a8684a55b19c90eae11ddd1cca011c2ff8270b5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py +++ /dev/null @@ -1,92 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.complement -""" - -import logging -import unittest - -from pycute import * - -_LOGGER = logging.getLogger(__name__) - - -class TestComplement(unittest.TestCase): - def helper_test_complement(self, layout): - layoutR = complement(layout) - - _LOGGER.debug(f"{layout} => {layoutR}") - - # Post-condition: test disjointness of the codomains - for a in range(size(layout)): - for b in range(size(layoutR)): - assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0) - - def test_complement(self): - test = Layout(1,0) - self.helper_test_complement(test) - - test = Layout(1,1) - self.helper_test_complement(test) - - test = Layout(4,0) - self.helper_test_complement(test) - - test = Layout((2,4),(1,2)) - self.helper_test_complement(test) - - test = Layout((2,3),(1,2)) - self.helper_test_complement(test) - - test = Layout((2,4),(1,4)) - self.helper_test_complement(test) - - test = Layout((2,4,8),(8,1,64)) - self.helper_test_complement(test) - - test = Layout(((2,2),(2,2)),((1,4),(8,32))) - self.helper_test_complement(test) - - test = Layout((2,(3,4)),(3,(1,6))) - self.helper_test_complement(test) - - test = Layout((4,6),(1,6)) - self.helper_test_complement(test) - - test = Layout((4,10),(1,10)) - self.helper_test_complement(test) - - -if __name__ == "__main__": - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py deleted file mode 100644 index 6c27eb7fe6cbb7bbbea7bd644ac8e64a2fc853c9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py +++ /dev/null @@ -1,213 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.composition -""" - -import logging -import unittest - -from pycute import * - -_LOGGER = logging.getLogger(__name__) - - -class TestComposition(unittest.TestCase): - def helper_test_composition(self, layoutA, layoutB): - layoutR = composition(layoutA, layoutB) - - _LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}") - - # True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. - - # Test that R(c) = A(B(c)) for all coordinates c in layoutR - for i in range(size(layoutR)): - self.assertEqual(layoutR(i), layoutA(layoutB(i))) - - def test_composition(self): - layoutA = Layout(1,0) - layoutB = Layout(1,0) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout(1,0) - layoutB = Layout(1,1) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout(1,1) - layoutB = Layout(1,0) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout(1,1) - layoutB = Layout(1,1) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4)) - layoutB = Layout((4)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4), (2)) - layoutB = Layout((4)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4)) - layoutB = Layout((4), (2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4), (0)) - layoutB = Layout((4)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4)) - layoutB = Layout((4), (0)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((1), (0)) - layoutB = Layout((4)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4)) - layoutB = Layout((1), (0)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4)) - layoutB = Layout((2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4), (2)) - layoutB = Layout((2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4)) - layoutB = Layout((2), (2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4), (2)) - layoutB = Layout((2), (2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((12)) - layoutB = Layout((4,3)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((12), (2)) - layoutB = Layout((4,3)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((12)) - layoutB = Layout((4,3), (3,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((12), (2)) - layoutB = Layout((4,3), (3,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((12)) - layoutB = Layout((2,3), (2,4)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3)) - layoutB = Layout((4,3)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3)) - layoutB = Layout((12)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3)) - layoutB = Layout((6), (2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3)) - layoutB = Layout((6,2), (2,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3), (3,1)) - layoutB = Layout((4,3)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3), (3,1)) - layoutB = Layout((12)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3), (3,1)) - layoutB = Layout((6), (2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,3), (3,1)) - layoutB = Layout((6,2), (2,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((8,8)) - layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((8,8), (8,1)) - layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) - layoutB = Layout(8, 4) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout(((4,2)), ((1,16))) - layoutB = Layout((4,2), (2,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((2,2), (2,1)) - layoutB = Layout((2,2), (2,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,8,2)) - layoutB = Layout((2,2,2), (2,8,1)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,8,2), (2,8,1)) - layoutB = Layout((2,2,2), (1,8,2)) - self.helper_test_composition(layoutA, layoutB) - - layoutA = Layout((4,8,2), (2,8,1)) - layoutB = Layout((4,2,2), (2,8,1)) - self.helper_test_composition(layoutA, layoutB) - - # Pre-coalesced LHS - layoutA = Layout((4,6,8),(1,4,7)) - layoutB = Layout((6),(1)) - self.helper_test_composition(layoutA, layoutB) - - # Mid-layout truncation - layoutA = Layout((4,6,8,10),(2,3,5,7)) - layoutB = Layout(6,12) - self.helper_test_composition(layoutA, layoutB) - -if __name__ == "__main__": - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py deleted file mode 100644 index 0dbf443c9725735b0051d0a225a55eece9c663a8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py +++ /dev/null @@ -1,80 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.int_tuple -""" - -import unittest - -from pycute import * - - -class TestIntTuple(unittest.TestCase): - def test_product(self): - self.assertEqual(product(2), 2) - - self.assertEqual(product((3,2)), 6) - - self.assertEqual(product(product(((2,3),4))), 24) - - def test_inner_product(self): - self.assertEqual(inner_product(2, 3), 6) - - self.assertEqual(inner_product((1,2), (3,2)), 7) - - self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15) - - def test_shape_div(self): - self.assertEqual(shape_div((3,4), 6), (1,2)) - - self.assertEqual(shape_div((3,4), 12), (1,1)) - - self.assertEqual(shape_div((3,4), 36), (1,1)) - - self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2)) - - self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2))) - - def test_prefix_product(self): - self.assertEqual(prefix_product(2), 1) - - self.assertEqual(prefix_product((3,2)), (1,3)) - - self.assertEqual(prefix_product((3,2,4)), (1,3,6)) - - self.assertEqual(prefix_product(((2,3),4)), ((1,2),6)) - - self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))), - ((1,2),(6,12,12),(24,120,240))) - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py deleted file mode 100644 index a6501fd6c7c6fc5a518e4d22bf93dc0e4746a8ba..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py +++ /dev/null @@ -1,87 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.left_inverse -""" - -import logging -import unittest - -from pycute import * - -_LOGGER = logging.getLogger(__name__) - - -class TestLeftInverse(unittest.TestCase): - def helper_test_left_inverse(self, layout): - inv_layout = left_inverse(layout) - - _LOGGER.debug(f"{layout} => {inv_layout}") - - for i in range(size(layout)): - self.assertEqual(inv_layout(layout(i)), i) - - def test_left_inverse(self): - test = Layout(1,0) - self.helper_test_left_inverse(test) - - test = Layout((1,1),(0,0)) - self.helper_test_left_inverse(test) - - test = Layout(1,1) - self.helper_test_left_inverse(test) - - test = Layout(4,1) - self.helper_test_left_inverse(test) - - test = Layout(4,2) - self.helper_test_left_inverse(test) - - test = Layout((8,4),(1,8)) - self.helper_test_left_inverse(test) - - test = Layout((8,4),(4,1)) - self.helper_test_left_inverse(test) - - test = Layout((2,4,6),(1,2,8)) - self.helper_test_left_inverse(test) - - test = Layout((2,4,6),(4,1,8)) - self.helper_test_left_inverse(test) - - test = Layout((4,2),(1,16)) - self.helper_test_left_inverse(test) - - -if __name__ == "__main__": - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py deleted file mode 100644 index 2ed9759d7808da8087fe9c76761d2dd9eaeab08b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py +++ /dev/null @@ -1,96 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.left_inverse -""" - -import logging -import unittest - -from pycute import * - -_LOGGER = logging.getLogger(__name__) - - -class TestRightInverse(unittest.TestCase): - def helper_test_right_inverse(self, layout): - inv_layout = right_inverse(layout) - - _LOGGER.debug(f"{layout} => {inv_layout}") - - for i in range(size(inv_layout)): - self.assertEqual(layout(inv_layout(i)), i) - - def test_right_inverse(self): - test = Layout(1,0) - self.helper_test_right_inverse(test) - - test = Layout((1,1),(0,0)) - self.helper_test_right_inverse(test) - - test = Layout((3,7),(0,0)) - self.helper_test_right_inverse(test) - - test = Layout(1,1) - self.helper_test_right_inverse(test) - - test = Layout(4,0) - self.helper_test_right_inverse(test) - - test = Layout(4,1) - self.helper_test_right_inverse(test) - - test = Layout(4,2) - self.helper_test_right_inverse(test) - - test = Layout((2,4),(0,2)) - self.helper_test_right_inverse(test) - - test = Layout((8,4),(1,8)) - self.helper_test_right_inverse(test) - - test = Layout((8,4),(4,1)) - self.helper_test_right_inverse(test) - - test = Layout((2,4,6),(1,2,8)) - self.helper_test_right_inverse(test) - - test = Layout((2,4,6),(4,1,8)) - self.helper_test_right_inverse(test) - - test = Layout((4,2),(1,16)) - self.helper_test_right_inverse(test) - - -if __name__ == "__main__": - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py deleted file mode 100644 index 9eb99a4833529e18fa22d65a235ce80dad372365..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py +++ /dev/null @@ -1,59 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Unit tests for pycute.typing -""" - -import logging -import unittest -from pycute import * - -_LOGGER = logging.getLogger(__name__) - - -class TestTyping(unittest.TestCase): - def helper_test_typing(self, _cls, _obj, cls, expected: bool): - _LOGGER.debug(f"issubclass({_cls}, {cls})") - _LOGGER.debug(f"isinstance({_obj}, {cls})") - - self.assertEqual(expected, issubclass(_cls, cls)) - self.assertEqual(expected, isinstance(_obj, cls)) - - def test_typing(self): - self.helper_test_typing(int, 1, Integer, True) - self.helper_test_typing(float, 1., Integer, False) - self.helper_test_typing(str, 'hi', Integer, False) - self.helper_test_typing(bool, False, Integer, False) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h deleted file mode 100644 index 86b7823785a9f2a957cf505740d6cfde45ccfef1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h +++ /dev/null @@ -1,102 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once -#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */ - -#pragma nv_diag_suppress boolean_controlling_expr_is_constant -#include -#pragma nv_diag_warning boolean_controlling_expr_is_constant -#pragma warning( disable : 4503) - -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Gets a CUDA device -cudaDeviceProp GetCudaDevice(); - -/// Prints device properties -std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Sets flags for Unit test -void FilterArchitecture(); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order -// of problem sizes run by CUTLASS unit tests -int CutlassUnitTestProblemCount(); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// active test macro -#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ - TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ - -// disabled test macro -#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ - TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} - -#if CUTLASS_TEST_LEVEL == 0 -#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#elif CUTLASS_TEST_LEVEL == 1 -#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#else -#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -#endif - -#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) -#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false -#endif - -#if (__CUDACC_VER_MAJOR__ >= 12) - #define CUDA_12_0_SM90_FEATURES_SUPPORTED true -#else - #define CUDA_12_0_SM90_FEATURES_SUPPORTED false -#endif - -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h deleted file mode 100644 index 3035e9862bcb79b749b4cbc4a74341bceac9c598..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h +++ /dev/null @@ -1,907 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Helper to construct cached name for -*/ -#pragma once - -#include -#include -#include -#include -#include - -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -#include "cutlass/conv/conv3d_problem_size.h" -#include "cutlass/core_io.h" -#include "cutlass/util/tensor_view_io.h" - -#include "thrust/universal_vector.h" - -#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS -#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test::conv::device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Result of a test -struct CachedTestKey { - - std::string op; ///< Concatenated string representation of operation performed - std::string problem; ///< Concatenated string representation of problem description - std::string types; ///< Concatenated string representation of operand types - uint32_t A; ///< Hashed result of tensor A - uint32_t B; ///< Hashed result of tensor B - uint32_t C; ///< Hashed result of tensor C - - // - // Methods - // - inline CachedTestKey(): A(), B(), C() { } - - inline CachedTestKey( - std::string op, ///< Concatenated string representation of operation performed - std::string problem, ///< Concatenated string representation of problem description - std::string types, ///< Concatenated string representation of operand types - uint32_t A, ///< Hashed result of tensor A - uint32_t B, ///< Hashed result of tensor B - uint32_t C ///< Hashed result of tensor C - ): - op(op), problem(problem), types(types), A(A), B(B), C(C) - { } - - /// Checks for equality of the problem - bool operator==(CachedTestKey const &rhs) const { - return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { - - in >> result.op; - in >> result.problem; - in >> result.types; - in >> result.A; - in >> result.B; - in >> result.C; - - return in; -} - -inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { - - out << result.op << " "; - out << result.problem << " "; - out << result.types << " "; - out << result.A << " "; - out << result.B << " "; - out << result.C << " "; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct CachedTestResult { - uint32_t D; - // - // Methods - // - - CachedTestResult(): D() - { } - - CachedTestResult(uint32_t D): D(D) - { } - - operator bool() const { - return bool(D); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { - in >> result.D; - return in; -} - -inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { - out << result.D; - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct CachedTestResultListing { - - std::list> results; - - // - // Methods - // - - inline CachedTestResultListing(std::string const &path) { - std::ifstream file(path); - - while (file.good()) { - CachedTestKey key; - file >> key; - - CachedTestResult result; - file >> result; - - if (result) { - results.push_back(std::make_pair(key, result)); - } - } - } - - /// Returns the cached result - std::pair find(CachedTestKey const &rhs) const { - for (auto const & result : results) { - if (result.first == rhs) { - return std::make_pair(true, result.second); - } - } - return std::make_pair(false, CachedTestResult()); - } - - /// Appends an entry - void append(CachedTestKey const &key, CachedTestResult const &result) { - if (result) { - results.push_back(std::make_pair(key, result)); - } - } - - /// Writes the entire listing to a file - bool write(std::string const &path) { - std::ofstream file(path); - if (!file.good()) { - return false; - } - - for (auto const &result : results) { - file << result.first << result.second << std::endl; - } - - return true; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct ScalarEncoder { - Element scalar; - - ScalarEncoder(Element s): scalar(s) { } - - std::string str() const { - std::stringstream ss; - Element s = scalar; - if (s < Element()) { - s = -s; - ss << "n"; - } - ss << s; - return ss.str(); - } -}; - -template -ScalarEncoder EncodeScalar(Element a) { - return ScalarEncoder(a); -} - -template -struct ScalarEncoder> { - cutlass::complex scalar; - - ScalarEncoder(cutlass::complex s): scalar(s) { } - - std::string str() const { - std::stringstream ss; - ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; - return ss.str(); - } -}; - -template -std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { - out << scalar.str(); - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { - switch (conv_op) { - case cutlass::conv::Operator::kFprop: return "fprop"; - case cutlass::conv::Operator::kDgrad: return "dgrad"; - case cutlass::conv::Operator::kWgrad: return "wgrad"; - case cutlass::conv::Operator::kDeconv: return "deconv"; - } - return "conv_unknown"; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Encode GemmCoord (Gemm problem size) -inline std::ostream &EncodeProblemSize( - std::ostream &out, - cutlass::gemm::GemmCoord const &problem) { - - out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_"; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Encode Conv2dProblemSize -inline std::ostream &EncodeProblemSize( - std::ostream &out, - cutlass::conv::Conv2dProblemSize const &problem) { - - out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" - << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; - - out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; - out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; - out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; - - switch (problem.mode) { - case cutlass::conv::Mode::kCrossCorrelation: - out << "corr"; - break; - case cutlass::conv::Mode::kConvolution: - out << "conv"; - break; - } - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Encode Conv3dProblemSize -inline std::ostream &EncodeProblemSize( - std::ostream &out, - cutlass::conv::Conv3dProblemSize const &problem) { - - out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" - << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; - - out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; - out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; - out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; - - switch (problem.mode) { - case cutlass::conv::Mode::kCrossCorrelation: - out << "corr"; - break; - case cutlass::conv::Mode::kConvolution: - out << "conv"; - break; - } - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Encode 3.x ConvNd ProblemShape -template -inline std::ostream &EncodeProblemSize( - std::ostream &out, - ProblemShape const& problem_shape) { - - out << problem_shape.shape_A << "_"; - out << problem_shape.shape_B << "_"; - - out << "padl" << problem_shape.lower_padding << "_"; - out << "padu" << problem_shape.upper_padding << "_"; - out << "str" << problem_shape.traversal_stride << "_"; - out << "dil" << problem_shape.dilation << "_"; - - switch (problem_shape.mode) { - case cutlass::conv::Mode::kCrossCorrelation: - out << "corr"; - break; - case cutlass::conv::Mode::kConvolution: - out << "conv"; - break; - } - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline std::string ElementTypeName() { - return std::string(typeid(Element).name()); -} - -template <> -inline std::string ElementTypeName() { - return "h"; -} - -template <> -inline std::string ElementTypeName>() { - return "ch"; -} - -template <> -inline std::string ElementTypeName() { - return "bf16"; -} - -template <> -inline std::string ElementTypeName>() { - return "cbf16"; -} - -template <> -inline std::string ElementTypeName() { - return "tf32"; -} - -template <> -inline std::string ElementTypeName>() { - return "ctf32"; -} - -template <> -inline std::string ElementTypeName>() { - return "c"; -} - -template <> -inline std::string ElementTypeName>() { - return "z"; -} - -template <> -inline std::string ElementTypeName>() { - return "q"; -} - -template <> -inline std::string ElementTypeName() { - return "s8"; -} - -template <> -inline std::string ElementTypeName() { - return "u8"; -} - -template <> -inline std::string ElementTypeName() { - return "s4"; -} - -template <> -inline std::string ElementTypeName() { - return "u4"; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline std::string LayoutTypeName() { - return std::string(typeid(Layout).name()); -} - -template <> -inline std::string LayoutTypeName() { - return "n"; -} - -template <> -inline std::string LayoutTypeName() { - return "t"; -} - -template <> -inline std::string LayoutTypeName() { - return "nhwc"; -} - -template <> -inline std::string LayoutTypeName>() { - return "nc32hw32"; -} - -template <> -inline std::string LayoutTypeName>() { - return "nc64hw64"; -} - -template <> -inline std::string LayoutTypeName>() { - return "c32rsk32"; -} - -template <> -inline std::string LayoutTypeName>() { - return "c64rsk64"; -} - -template <> -inline std::string LayoutTypeName() { - return "ndhwc"; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline std::string TensorTypeName() { - std::stringstream ss; - ss << ElementTypeName() << LayoutTypeName(); - return ss.str(); -} - -template -inline std::string TensorTypeName() { - std::stringstream ss; - ss << ElementTypeName(); - return ss.str(); -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Hash function on a byte array -struct CRC32 { - - uint32_t table[256]; - - // - // Methods - // - - CRC32() { - - uint32_t rem; - int i, j; - - for (i = 0; i < 256; i++) { - rem = i; - for (j = 0; j < 8; j++) { - if (rem & 1) { - rem >>= 1; - rem ^= 0xedb88320; - } else - rem >>= 1; - } - table[i] = rem; - } - } - - /// Computes the CRC of an array of bytes - uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { - uint8_t const *p = static_cast(start); - uint8_t const *q = static_cast(start) + length; - - crc = ~crc; - - for (; p != q; ++p) { - uint8_t octet = *p; - crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; - } - - return ~crc; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Element, typename Layout -> -uint32_t TensorHash( - cutlass::TensorView view, - CRC32 const &hash = CRC32(), - uint32_t crc = uint32_t() -) { - - return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); -} - -template -uint32_t TensorHash( - thrust::universal_vector& tensor, - CRC32 const &hash = CRC32(), - uint32_t crc = uint32_t() -) { - - return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits::value / 8, crc); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, typename LayoutA, - typename ElementB, typename LayoutB, - typename ElementC, typename LayoutC, - typename ElementAccumulator, - typename ElementCompute -> -inline std::ostream &EncodeTypes( - std::ostream &out -) { - - out << TensorTypeName() << "_" - << TensorTypeName() << "_" - << TensorTypeName() << "_" - << ElementTypeName() << "_" - << ElementTypeName(); - - return out; -} - -template < - typename ElementA, - typename ElementB, - typename ElementC, - typename ElementD -> -inline std::ostream &EncodeTypes( - std::ostream &out -) { - - out << TensorTypeName() << "_" - << TensorTypeName() << "_" - << TensorTypeName() << "_" - << ElementTypeName(); - - return out; -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, typename LayoutA, - typename ElementB, typename LayoutB, - typename ElementC, typename LayoutC, - typename ElementAccumulator, - typename ElementCompute -> -inline CachedTestKey CreateCachedGemmTestKey( - cutlass::gemm::GemmCoord const &problem, - ElementCompute alpha, - ElementCompute beta, - cutlass::TensorView A, - cutlass::TensorView B, - cutlass::TensorView C -) { - - CachedTestKey key; - - // Encode gemm operator and problem sizes - key.op = "gemm"; - - std::stringstream ss_problem; - EncodeProblemSize(ss_problem, problem); - ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); - key.problem = ss_problem.str(); - - // Encode problem data types - std::stringstream ss_types; - EncodeTypes< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute>(ss_types); - key.types = ss_types.str(); - - // Encode hash for problem data - CRC32 crc_hash; - key.A = TensorHash(A, crc_hash); - key.B = TensorHash(B, crc_hash); - key.C = TensorHash(C, crc_hash); - - return key; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - typename ElementA, typename LayoutA, - typename ElementB, typename LayoutB, - typename ElementC, typename LayoutC, - typename ElementAccumulator, - typename ElementCompute -> -inline CachedTestKey CreateCachedConv2dTestKey( - - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem, - ElementCompute alpha, - ElementCompute beta, - cutlass::TensorView A, - cutlass::TensorView B, - cutlass::TensorView C -) { - - CachedTestKey key; - - // Encode conv2d operator and problem sizes - key.op = "conv2d"; - - std::stringstream ss_problem; - ss_problem << EncodeOperator(conv_operator) << "_"; - EncodeProblemSize(ss_problem, problem); - ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); - - key.problem = ss_problem.str(); - - // Encode problem data types - std::stringstream ss_types; - EncodeTypes< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute>(ss_types); - key.types = ss_types.str(); - - // Encode hash for problem data - CRC32 crc_hash; - - key.A = TensorHash(A, crc_hash); - key.B = TensorHash(B, crc_hash); - key.C = TensorHash(C, crc_hash); - - return key; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, typename LayoutA, - typename ElementB, typename LayoutB, - typename ElementC, typename LayoutC, - typename ElementAccumulator, - typename ElementCompute -> -inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( - - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem, - ElementCompute alpha, - ElementCompute beta, - cutlass::TensorView A, - cutlass::TensorView B, - cutlass::TensorView C -) { - - CachedTestKey key; - - // Encode conv2d operator and problem sizes - key.op = "conv2d_with_broadcast"; - - std::stringstream ss_problem; - ss_problem << EncodeOperator(conv_operator) << "_"; - EncodeProblemSize(ss_problem, problem); - ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); - - key.problem = ss_problem.str(); - - // Encode problem data types - std::stringstream ss_types; - EncodeTypes< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute>(ss_types); - key.types = ss_types.str(); - - // Encode hash for problem data - CRC32 crc_hash; - - key.A = TensorHash(A, crc_hash); - key.B = TensorHash(B, crc_hash); - key.C = TensorHash(C, crc_hash); - - return key; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, typename LayoutA, - typename ElementB, typename LayoutB, - typename ElementC, typename LayoutC, - typename ElementAccumulator, - typename ElementCompute -> -inline CachedTestKey CreateCachedConv2dWithReductionTestKey( - - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem, - ElementCompute alpha, - ElementCompute beta, - cutlass::TensorView A, - cutlass::TensorView B, - cutlass::TensorView C -) { - - CachedTestKey key; - - // Encode conv2d operator and problem sizes - key.op = "conv2d_with_reduction"; - - std::stringstream ss_problem; - ss_problem << EncodeOperator(conv_operator) << "_"; - EncodeProblemSize(ss_problem, problem); - ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); - - key.problem = ss_problem.str(); - - // Encode problem data types - std::stringstream ss_types; - EncodeTypes< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute>(ss_types); - key.types = ss_types.str(); - - // Encode hash for problem data - CRC32 crc_hash; - - key.A = TensorHash(A, crc_hash); - key.B = TensorHash(B, crc_hash); - key.C = TensorHash(C, crc_hash); - - return key; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, typename LayoutA, - typename ElementB, typename LayoutB, - typename ElementC, typename LayoutC, - typename ElementAccumulator, - typename ElementCompute -> -inline CachedTestKey CreateCachedConv3dTestKey( - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv3dProblemSize const &problem, - ElementCompute alpha, - ElementCompute beta, - cutlass::TensorView A, - cutlass::TensorView B, - cutlass::TensorView C -) { - - CachedTestKey key; - - // Encode conv3d operator and problem sizes - key.op = "conv3d"; - - std::stringstream ss_problem; - - ss_problem << EncodeOperator(conv_operator) << "_"; - EncodeProblemSize(ss_problem, problem); - ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); - - key.problem = ss_problem.str(); - - // Encode problem data types - std::stringstream ss_types; - EncodeTypes< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute>(ss_types); - key.types = ss_types.str(); - - // Encode problem data - CRC32 crc_hash; - key.A = TensorHash(A, crc_hash); - key.B = TensorHash(B, crc_hash); - key.C = TensorHash(C, crc_hash); - - return key; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape, - typename ElementA, - typename ElementB, - typename ElementC, - typename ElementD -> -inline CachedTestKey CreateCachedConvNd3xTestKey( - cutlass::conv::Operator conv_operator, - ProblemShape const& problem_shape, - double alpha, - double beta, - thrust::universal_vector A, - thrust::universal_vector B, - thrust::universal_vector C -) { - - CachedTestKey key; - - // Encode convNd operator and problem sizes - std::stringstream ss_op; - ss_op << "conv" << ProblemShape::RankS << "d"; - key.op = ss_op.str(); - - std::stringstream ss_problem; - ss_problem << EncodeOperator(conv_operator) << "_"; - EncodeProblemSize(ss_problem, problem_shape); - ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); - key.problem = ss_problem.str(); - - // Encode problem data types - std::stringstream ss_types; - EncodeTypes< - ElementA, - ElementB, - ElementC, - ElementD>(ss_types); - key.types = ss_types.str(); - - // Encode problem data - CRC32 crc_hash; - key.A = TensorHash(A, crc_hash); - key.B = TensorHash(B, crc_hash); - key.C = TensorHash(C, crc_hash); - - return key; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace test::conv::device - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h deleted file mode 100644 index a14134b2854732e669977831207a456d28beed9f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h +++ /dev/null @@ -1,927 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed sizes for Conv2d problem -*/ -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" - -namespace test { -namespace conv { -namespace device { - -using Conv2dProblemVector = std::vector; - -// -// Structures to prune items from Conv2dProblemVector -// -// Specification template for pruning items for convolution problem lists -template struct Specification -{ - virtual ~Specification() = default; - virtual bool is_satisfied(T item) const = 0; -}; - -// input size (NHWC) specification -struct InputSizeSpecification : Specification -{ - cutlass::Tensor4DCoord input_size; - - InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {} - - bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { - return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C)); - } -}; - -// stride (stride_h, stride_w) specification -struct StrideSpecification : Specification -{ - cutlass::MatrixCoord stride; - - StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {} - - bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { - return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h)); - } -}; - -// channel (C,K) specification, must be multiple of minimum channel -struct ChannelDivisibilitySpecification : Specification -{ - int channel_multiple; - - ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {} - - bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { - return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0)); - } -}; - -// -// Pruning function for items from Conv2dProblemVector based on a Specification -// -inline Conv2dProblemVector prune(Conv2dProblemVector const &items, - Specification const &spec) -{ - Conv2dProblemVector pruned_list; - - for (auto& p : items) - if (spec.is_satisfied(p)) - pruned_list.push_back(p); - return pruned_list; -} - - -//////////////////////////////////////////////////////////////////////////// -/// Structure TestbedConv2dProblemSizes initializes and holds conv default and -/// important network sizes -//////////////////////////////////////////////////////////////////////////// -struct TestbedConv2dProblemSizes { - - // - // Data members - // - int minimum_channel_size; - - Conv2dProblemVector conv2d_default_sizes; - Conv2dProblemVector conv2d_rigorous_sizes; - Conv2dProblemVector conv2d_resnet50_sizes; - Conv2dProblemVector conv2d_resnet50_sizes_perf; - - // - // Methods - // - /// Default ctor - TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { - initialize_conv2d_default_sizes(); - initialize_conv2d_rigorous_sizes(); - initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/); - - initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/); - filter_all(); - } - - /// Eliminates some illegal cases - void filter_all() { - - Conv2dProblemVector *problems_vectors[] = { - &conv2d_default_sizes, - &conv2d_rigorous_sizes, - &conv2d_resnet50_sizes, - &conv2d_resnet50_sizes_perf - }; - - for (Conv2dProblemVector *problems : problems_vectors) { - Conv2dProblemVector filtered; - - for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { - if (!(problem.C % minimum_channel_size)) { - filtered.push_back(problem); - } - } - - *problems = filtered; - } - } - - // Add a few standard convolution problem sizes - void initialize_conv2d_default_sizes() { - - //////////////////////////////////////////////////////////////////////////////////////////// - // Small input size x stride (1,1) - // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - //////////////////////////////////////////////////////////////////////////////////////////// - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 1, 1, minimum_channel_size}, // input size (NHWC) - {8, 1, 1, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 1, 8, minimum_channel_size}, // input size (NHWC) - {8, 1, 3, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 7, 8, minimum_channel_size}, // input size (NHWC) - {8, 3, 3, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 4, 4, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {2, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 5, 5, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 6, 5, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 6, 6, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 7, 7, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////////////// - // Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0) - // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - //////////////////////////////////////////////////////////////////////////////////////////// - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 1, 1, minimum_channel_size}, // input size (NHWC) - {8, 1, 1, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 1, 8, minimum_channel_size}, // input size (NHWC) - {8, 1, 3, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 7, 8, minimum_channel_size}, // input size (NHWC) - {8, 3, 3, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 4, 4, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {2, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 5, 5, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 6, 5, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 6, 6, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 7, 9, minimum_channel_size}, // input size (NHWC) - {8, 7, 7, minimum_channel_size}, // filter size (KRSC) - {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////////////// - // Small input size x stride (2,2) - // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - //////////////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 11, 7, minimum_channel_size}, // input size (NHWC) - {8, 1, 1, minimum_channel_size}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 11, 7, minimum_channel_size}, // input size (NHWC) - {8, 3, 3, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 13, 11, minimum_channel_size}, // input size (NHWC) - {8, 1, 1, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 17, 19, minimum_channel_size}, // input size (NHWC) - {16, 2, 2, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 23, 5, minimum_channel_size}, // input size (NHWC) - {16, 3, 3, minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 13, 17, 8}, // input size (NHWC) - {24, 3, 3, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 23, 21, 8}, // input size (NHWC) - {24, 3, 3, 8}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {3, 3}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 20, 24, 8}, // input size (NHWC) - {40, 3, 3, 8}, // filter size (KRSC) - {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) - {3, 3}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////// - // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 15, 19, 160}, // input size (NHWC) - {224, 1, 1, 160}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 19, 37, 160}, // input size (NHWC) - {224, 3, 3, 160}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 16, 16, 160}, // input size (NHWC) - {224, 2, 3, 160}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 23, 21, 128}, // input size (NHWC) - {224, 3, 3, 128}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 29, 37, 160}, // input size (NHWC) - {224, 5, 5, 160}, // filter size (KRSC) - {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////// - // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) - {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) - {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////// - // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 13, 16, 288}, // input size (NHWC) - {160, 5, 5, 288}, // filter size (KRSC) - {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 55, 51, 256}, // input size (NHWC) - {512, 1, 1, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 71, 80, 32}, // input size (NHWC) - {64, 5, 5, 32}, // filter size (KRSC) - {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 224, 224, 8}, // input size (NHWC) - {64, 7, 7, 8}, // filter size (KRSC) - {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////// - // Medium input size stride (3, 3), filter (3, 3), non-default padding - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 27, 23, 256}, // input size (NHWC) - {512, 3, 3, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {3, 3}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////// - // Medium input size padding > stride, asymmetric filter, padding and striding - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 27, 31, 256}, // input size (NHWC) - {512, 3, 3, 256}, // filter size (KRSC) - {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) - {3, 4}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 27, 35, 256}, // input size (NHWC) - {512, 7, 5, 256}, // filter size (KRSC) - {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) - {3, 5}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - //////////////////////////////////////////////////////////////////////////////////// - // Medium input size *mixed* stride (1, 2) and (2, 1), - // filter (3, 3), default padding - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 27, 27, 256}, // input size (NHWC) - {512, 3, 3, 256}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 27, 27, 256}, // input size (NHWC) - {512, 3, 3, 256}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - ///////////////////////////////////////////////////////////////////////////// - // Additional input size - ///////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {3, 28, 28, 256}, // input size (NHWC) - {256, 2, 2, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 32, 32, 16}, // input size (NHWC) - {32, 3, 3, 16}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {6, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {32, 24, 32, 32}, // input size (NHWC) - {32, 1, 2, 32}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {4, 4, 5, 128}, // input size (NHWC) - {256, 3, 6, 128}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - {4, 3, 3, 256} // output size (NPQK) - )); - - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {4, 2, 3, 256}, // input size (NHWC) - {328, 3, 5, 256}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - {4, 1, 1, 328} // output size (NPQK) - )); - } - - - // Add a few large and rigorous convolution problem sizes - void initialize_conv2d_rigorous_sizes() { - -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 124, 224, 96}, // input size (NHWC) - {24, 7, 7, 96}, // filter size (KRSC) - {1, 229, 129, 32} // output size (NPQK) - )); - - conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 233, 35, 48}, // input size (NHWC) - {24, 7, 5, 48}, // filter size (KRSC) - {1, 233, 35, 24} // output size (NPQK) - )); - -#endif - - } - - - // Add resent50 layers to unit testing sizes - void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ - -#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - [1, 224, 224, 3], // input size (NHWC) - [64, 7, 7, 3], // filter size (KRSC) - [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) - [2, 2], // stride (stride_h, stride_w) - [1, 1], // dilation (dilation_h, dilation_w) - )); -#endif - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 56, 56, 64}, // input size (NHWC) - {256, 1, 1, 64}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 56, 56, 64}, // input size (NHWC) - {64, 1, 1, 64}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 56, 56, 64}, // input size (NHWC) - {64, 3, 3, 64}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 56, 56, 256}, // input size (NHWC) - {64, 1, 1, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 56, 56, 256}, // input size (NHWC) - {512, 1, 1, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 56, 56, 256}, // input size (NHWC) - {128, 1, 1, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 28, 28, 128}, // input size (NHWC) - {128, 3, 3, 128}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 28, 28, 128}, // input size (NHWC) - {512, 1, 1, 128}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 28, 28, 512}, // input size (NHWC) - {128, 1, 1, 512}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 28, 28, 512}, // input size (NHWC) - {1024, 1, 1, 512}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 28, 28, 512}, // input size (NHWC) - {256, 1, 1, 512}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 14, 14, 256}, // input size (NHWC) - {256, 3, 3, 256}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 14, 14, 256}, // input size (NHWC) - {1024, 1, 1, 256}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 14, 14, 1024}, // input size (NHWC) - {256, 1, 1, 1024}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 14, 14, 1024}, // input size (NHWC) - {2048, 1, 1, 1024}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 14, 14, 1024}, // input size (NHWC) - {512, 1, 1, 1024}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 7, 7, 512}, // input size (NHWC) - {512, 3, 3, 512}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 7, 7, 512}, // input size (NHWC) - {2048, 1, 1, 512}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( - {batch_size, 7, 7, 2048}, // input size (NHWC) - {512, 1, 1, 2048}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - } - -}; - - -//////////////////////////////////////////////////////////////////////////// -/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and -/// important network sizes -//////////////////////////////////////////////////////////////////////////// -struct TestbedGroupConv2dProblemSizes { - - // - // Data members - // - int threadblock_n; - int threadblock_k; - int minimum_channel_size; - - Conv2dProblemVector default_single_group_sizes; - Conv2dProblemVector default_multiple_group_sizes; - - // - // Methods - // - /// Default ctor - TestbedGroupConv2dProblemSizes( - int threadblock_n_, - int threadblock_k_, - int minimum_channel_size_ = 64) - : threadblock_n (threadblock_n_), - threadblock_k (threadblock_k_), - minimum_channel_size (minimum_channel_size_) { - initialize_group_conv2d_default_sizes(); - filter_all(); - } - - /// Eliminates some illegal cases - void filter_all() { - - Conv2dProblemVector *problems_vectors[] = { - &default_single_group_sizes, - &default_multiple_group_sizes - }; - - for (Conv2dProblemVector *problems : problems_vectors) { - Conv2dProblemVector filtered; - - for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { - if (!((problem.C / problem.groups) % minimum_channel_size)) { - filtered.push_back(problem); - } - } - - *problems = filtered; - } - } - - // Add a few standard convolution problem sizes - void initialize_group_conv2d_default_sizes() { - - //////////////////////////////////////////////////////////////////////////////////// - // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 - // One CTA calculates a single group - //////////////////////////////////////////////////////////////////////////////////// - - for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { - // groups = 2, 3, 4 - for (int groups = 2; groups < 5; ++groups) { - - int conv_k = cta_per_group_k * threadblock_n * groups; - default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) - {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - groups // groups - )); - - } // loop groups - } // loop cta_per_group_k - - // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K - default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, threadblock_k}, // input size (NHWC) - {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 2 // groups - )); - - // Larger problem sizes - - default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 56, 56, 696}, // input size (NHWC) - {768, 3, 3, 232}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 3 // groups - )); - default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 14, 14, 1392}, // input size (NHWC) - {1536, 3, 3, 232}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 3 // groups - )); - - //////////////////////////////////////////////////////////////////////////////////// - // One CTA calculate multiple groups: CTA::N % k_per_group = 0 - //////////////////////////////////////////////////////////////////////////////////// - - // 2 groups per CTA - default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, threadblock_k * 4}, // input size (NHWC) - {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 2 // groups - )); - - // 2 groups per CTA and partial gemm_k - default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, threadblock_k}, // input size (NHWC) - {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 2 // groups - )); - - // 4 groups per CTA - default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, threadblock_k * 8}, // input size (NHWC) - {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 4 // groups - )); - - // 4 groups per CTA and partial gemm_k - default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, threadblock_k * 2}, // input size (NHWC) - {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - cutlass::conv::Mode::kCrossCorrelation, - 1, // split_k_slices - 4 // groups - )); - } - -}; - - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h deleted file mode 100644 index 34588ecb467b824cc0fcbbff0bc0d99e4385d80e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h +++ /dev/null @@ -1,818 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/reduction/device/reduce_split_k.h" -#include "cutlass/reduction/thread/reduction_operators.h" - -#include "conv2d_problems.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_compare.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/device/convolution.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/tensor_view_io.h" - -#include "../cache_testbed_output.h" - -namespace test { -namespace conv { -namespace device { - -template -class TestbedConv2d { -public: - - using ElementA = typename Conv2d::ElementA; - using LayoutA = typename Conv2d::LayoutA; - using ElementB = typename Conv2d::ElementB; - using LayoutB = typename Conv2d::LayoutB; - using ElementC = typename Conv2d::ElementC; - using LayoutC = typename Conv2d::LayoutC; - using ElementAccumulator = typename Conv2d::ElementAccumulator; - using ElementCompute = typename Conv2d::ElementCompute; - using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; - - /// Reduction kernel - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, - typename EpilogueOutputOp::ElementAccumulator, - EpilogueOutputOp::kCount - >; - - using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< - cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, - EpilogueOutputOp, - ReductionOp - >; - - using ReductionDevice = cutlass::reduction::device::ReduceSplitK; - using ReductionStrideIndex = typename ReductionDevice::StrideIndex; - -public: - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - int tested_problem_count; - -public: - - TestbedConv2d( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { - - } - - /// Helper to initialize a tensor view - template - void initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } - else if (bits == 16) { - if (cutlass::sizeof_bits::value <= 16) { - scope = 3; - } - else { - scope = 5; - } - } - else { - scope = 8; - } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope, -scope, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - } - } - - void initialize( - cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { - - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_D_reference.sync_device(); - } - - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::conv::Conv2dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - // increment tested problem count run by the testbed - tested_problem_count++; - -#if 0 // display conv2d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl - << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl - << std::endl; -#endif - - initialize(problem_size); - - // configure the operator - Conv2d conv2d_op; - - typename Conv2d::Arguments conv2d_args( - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_computed.device_ref(), - {alpha, beta}, - split_k_mode - ); - - // find workspace requirement for parallel split-k reduction - size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // conv2d operation with parallel split-k-mode - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // conv2d output is written to workspace in global memory - conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); - // accumulate mma for each cta in k-dimension (1.0 * A * B) - conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; - // update conv2d operator arguments - status = conv2d_op.update(conv2d_args, workspace.get()); - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run conv2d operator - status = conv2d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run." << std::endl; - return false; - } - - - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // configure parallel reduction operator - ReductionDevice reduction_op; - - typename ReductionDevice::Arguments reduction_args( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), - problem_size.split_k_slices, - cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), - { - reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) - }, - { - tensor_D_computed.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) - }, - { - tensor_C.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) - }, - // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C - {alpha, beta} - ); - - status = reduction_op.initialize(reduction_args, nullptr); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run prallel reduction kernel - status = reduction_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - } - bool passed = false; - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - - tensor_D_computed.sync_host(); - - // - // Reference check - support caching results - // - - CachedTestKey cached_test_key = CreateCachedConv2dTestKey< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute - >( - kConvolutionalOperator, - problem_size, - alpha, - beta, - tensor_A.host_view(), - tensor_B.host_view(), - tensor_C.host_view() - ); - - // - // Look for the cached key - // - - bool cached_result_loaded = false; - CachedTestResult cached_test_result; - - std::string conv2d_result_cache_name = - std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - auto cached = cached_results.find(cached_test_key); - - cached_result_loaded = cached.first; - if (cached_result_loaded) { - cached_test_result = cached.second; - } - } - - if (!cached_result_loaded) { - -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_reference.device_ref(), - alpha, - beta); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_D_reference.sync_host(); - -#else - - cutlass::reference::host::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C.host_ref(), - tensor_D_reference.host_ref(), - alpha, - beta); - -#endif - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv2d_result_cache_name); - } - } // if (!cached_result_loaded) - - uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - passed = (tensor_D_hash == cached_test_result.D); - - EXPECT_EQ(tensor_D_hash, cached_test_result.D) - << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; - } - else { - - passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view()); - } - - EXPECT_TRUE(passed); - - std::stringstream ss_problem_size_text; - ss_problem_size_text << "nhwc_" - << problem_size.N << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_krsc_" - << problem_size.K << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv2d_ImplicitGemm_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) - << ss_problem_size_text.str() - << Conv2d::ThreadblockShape::kM << "x" - << Conv2d::ThreadblockShape::kN << "x" - << Conv2d::ThreadblockShape::kK << "_" - << Conv2d::WarpShape::kM << "x" - << Conv2d::WarpShape::kN << "x" - << Conv2d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n"; - - results << "\nD reference (hash: " << cached_test_result.D << ")\n"; - - if (!cached_result_loaded) { - results - << tensor_D_reference.host_view() << "\n"; - } - - results - << "\nD computed (hash: " << tensor_D_hash << ")\n" - << tensor_D_computed.host_view() << "\n"; - - } - - return passed; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestSpecificConv2d( - const Conv2dProblemVector & problem_sizes) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv2d testbed; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for(auto conv_problem : problem_sizes) { - - // - // Test - // - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - - return true; -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -// (conv_blacklist_sizes) -///////////////////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestAllConv2d( - const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), - const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv2d testbed; - - // - // Get conv problem sizes to run conv operator - // - TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); - - // Vector of conv2d problem sizes to avoid duplicate runs - Conv2dProblemVector conv_tested_sizes; - - // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) - std::vector problem_vectors = { - conv_test_sizes, // run user specified sizes - conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -#endif - }; - - // Flatten 2D problem_vectors into a 1D problem_sizes - std::vector problem_sizes; - for (auto problem_vector : problem_vectors) { - for(auto conv_problem : problem_vector) { - problem_sizes.push_back(conv_problem); - } - } - - // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) - // run the most rigorous problem size first - if (CutlassUnitTestProblemCount()) { - std::reverse(problem_sizes.begin(), problem_sizes.end()); - } - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for(auto conv_problem : problem_sizes) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Procedurally disable certain cases - // - - // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kUnity)) { - if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } - - // Fixed channels algorithm requires channel count to match access size - if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == - cutlass::conv::IteratorAlgorithm::kFixedChannels) { - if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { - continue; - } - } - - // Few channels algorithm requires channel count to match access size - if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == - cutlass::conv::IteratorAlgorithm::kFewChannels) { - if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { - continue; - } - } - - // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} - // Although strided dgrad works for all stride combinations, we are only going - // to run strided dgrad for non-unity strides - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts - if (CutlassUnitTestProblemCount() && - testbed.tested_problem_count > CutlassUnitTestProblemCount()) { - return true; - } - } - - // Small-channels convolution can't run here. - if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == - cutlass::conv::IteratorAlgorithm::kFixedChannels) { - - return true; - } - - // Small-channels convolution can't run here. - if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == - cutlass::conv::IteratorAlgorithm::kFewChannels) { - - return true; - } - - // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - - passed = testbed.run( - cutlass::conv::Conv2dProblemSize( - {1, 56, 56, 8}, // input size (NHWC) - {8, 1, 1, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1}), // dilation (dilation_h, dilation_w) - cutlass::conv::SplitKMode::kSerial, - cutlass::from_real(2.0), - cutlass::from_real(2.0)); - - passed = testbed.run( - cutlass::conv::Conv2dProblemSize( - {1, 56, 56, 8}, // input size (NHWC) - {8, 1, 1, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}) // dilation (dilation_h, dilation_w) - .reset_split_k_slices(2), - cutlass::conv::SplitKMode::kSerial, - cutlass::from_real(2.0), - cutlass::from_real(2.0)); - - if (!passed) { - return false; - } - - return passed; - } - // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for - // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters - // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep - // alpha and beta for local testing, but only runs one value for alpha and beta. - cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( - {1, 17, 11, 288}, // input size (NHWC) - {160, 3, 3, 288}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - ); - - cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial, - cutlass::conv::SplitKMode::kParallel, - }; - - int split_k_slices[] = { - 1, 2, 3, 4, 201 - }; - - double problem_alpha[] = { - 2.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (auto split_k_mode : split_k_modes) { - for (auto split_k_slice : split_k_slices) { - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - passed = testbed.run( - conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - return false; - } - - // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts - if (CutlassUnitTestProblemCount() && - testbed.tested_problem_count > CutlassUnitTestProblemCount()) { - return true; - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h deleted file mode 100644 index cf075674da673cf8e056172732f912b8acba3c5b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h +++ /dev/null @@ -1,666 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/reduction/device/reduce_split_k.h" -#include "cutlass/reduction/thread/reduction_operators.h" - -#include "conv2d_problems.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/host_reorder.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/device/convolution.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/tensor_view_io.h" - -#include "../cache_testbed_output.h" - -namespace test { -namespace conv { -namespace device { - -template -class InterleavedTestbedConv2d { -public: - - using ElementA = typename Conv2d::ElementA; - using LayoutA = typename Conv2d::LayoutA; - using ElementB = typename Conv2d::ElementB; - using LayoutB = typename Conv2d::LayoutB; - using ElementC = typename Conv2d::ElementC; - using LayoutC = typename Conv2d::LayoutC; - using ElementAccumulator = typename Conv2d::ElementAccumulator; - using ElementCompute = typename Conv2d::ElementCompute; - using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; - - /// Reduction kernel - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, - typename EpilogueOutputOp::ElementAccumulator, - EpilogueOutputOp::kCount - >; - - using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< - cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, - EpilogueOutputOp, - ReductionOp - >; - - using ReductionDevice = cutlass::reduction::device::ReduceSplitK; - using ReductionStrideIndex = typename ReductionDevice::StrideIndex; - -public: - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_B_reordered; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - -public: - - InterleavedTestbedConv2d( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { - - } - - /// Helper to initialize a tensor view - template - void initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } - else if (bits == 16) { - scope = 3; - } - else { - scope = 8; - } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope, -scope, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - } - } - - void initialize( - cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { - - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - - cutlass::reorder_convK( - tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_B_reordered.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_D_reference.sync_device(); - } - - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerMultiprocessor < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::conv::Conv2dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 //display conv2d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl - << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl - << std::endl; -#endif - - initialize(problem_size); - - // configure the operator - Conv2d conv2d_op; - - typename Conv2d::Arguments conv2d_args( - problem_size, - tensor_A.device_ref(), - tensor_B_reordered.device_ref(), - tensor_C.device_ref(), - tensor_D_computed.device_ref(), - {alpha, beta}, - split_k_mode - ); - - // find workspace requirement for parallel split-k reduction - size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); - - // conv2d operation with parallel split-k-mode - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // conv2d output is written to workspace in global memory - conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); - // accumulate mma for each cta in k-dimension (1.0 * A * B) - conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; - // update conv2d operator arguments - status = conv2d_op.update(conv2d_args, workspace.get()); - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run conv2d operator - status = conv2d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // configure parallel reduction operator - ReductionDevice reduction_op; - - typename ReductionDevice::Arguments reduction_args( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), - problem_size.split_k_slices, - cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), - { - reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) - }, - { - tensor_D_computed.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) - }, - { - tensor_C.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) - }, - // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C - {alpha, beta} - ); - - status = reduction_op.initialize(reduction_args, nullptr); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run prallel reduction kernel - status = reduction_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - } - bool passed = false; - - tensor_D_computed.sync_host(); - - // - // Reference check - support caching results - // - - CachedTestKey cached_test_key = CreateCachedConv2dTestKey< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute - >( - kConvolutionalOperator, - problem_size, - alpha, - beta, - tensor_A.host_view(), - tensor_B.host_view(), - tensor_C.host_view() - ); - - // - // Look for the cached key - // - - bool cached_result_loaded = false; - CachedTestResult cached_test_result; - - std::string conv2d_result_cache_name = - std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - auto cached = cached_results.find(cached_test_key); - - cached_result_loaded = cached.first; - if (cached_result_loaded) { - cached_test_result = cached.second; - } - } - - if (!cached_result_loaded) { - -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - cutlass::NumericConverterClamp - >( - kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_reference.device_ref(), - alpha, - beta); - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_D_reference.sync_host(); - -#else - - cutlass::reference::host::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ElementC, - cutlass::NumericConverterClamp - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C.host_ref(), - tensor_D_reference.host_ref(), - alpha, - beta); - -#endif - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv2d_result_cache_name); - } - } // if (!cached_result_loaded) - - uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - passed = (tensor_D_hash == cached_test_result.D); - - EXPECT_EQ(tensor_D_hash, cached_test_result.D) - << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; - } - else { - - passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view()); - } - - EXPECT_TRUE(passed); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv2d_ImplicitGemm_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) - << "ncxhwx_" - << problem_size.N << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_cxrskx_" - << problem_size.K << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") - << Conv2d::ThreadblockShape::kM << "x" - << Conv2d::ThreadblockShape::kN << "x" - << Conv2d::ThreadblockShape::kK << "_" - << Conv2d::WarpShape::kM << "x" - << Conv2d::WarpShape::kN << "x" - << Conv2d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n"; - - results << "\nD reference (hash: " << cached_test_result.D << ")\n"; - - if (!cached_result_loaded) { - results - << tensor_D_reference.host_view() << "\n"; - } - - results - << "\nD computed (hash: " << tensor_D_hash << ")\n" - << tensor_D_computed.host_view() << "\n"; - - } - - return passed; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -// (conv_blacklist_sizes) -///////////////////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestAllInterleavedConv2d( - const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), - const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { - - bool passed = true; - - // - // Testbed object - // - - InterleavedTestbedConv2d testbed; - - // - // Get conv problem sizes to run conv operator - // - TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout - - // Vector of conv2d problem sizes to avoid duplicate runs - Conv2dProblemVector conv_tested_sizes; - - Conv2dProblemVector const *problem_vectors[] = { - &conv_test_sizes, // run user specified sizes - &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -#endif - }; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv2dProblemVector const * problem_vector : problem_vectors) { - - ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK - auto pruned_problem_vector = prune(*problem_vector, channel_spec); - - // Run conv testbed on default convolution sizes - for(auto conv_problem : pruned_problem_vector) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Procedurally disable certain cases - // - - // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kUnity)) { - if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - } - -#if 0 - // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for - // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters - // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep - // alpha and beta for local testing, but only runs one value for alpha and beta. - cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( - {1, 17, 11, 288}, // input size (NHWC) - {160, 3, 3, 288}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - ); - - cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial, - cutlass::conv::SplitKMode::kParallel, - }; - - int split_k_slices[] = { - 1, 2, 3, 4, 201 - }; - - double problem_alpha[] = { - 2.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (auto split_k_mode : split_k_modes) { - for (auto split_k_slice : split_k_slices) { - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - passed = testbed.run( - conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - return false; - } - } - } - } - } -#endif - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h deleted file mode 100644 index ad7b2ce61a66a79f852c0aac0895d10ba18e5466..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h +++ /dev/null @@ -1,622 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling -*/ - -#pragma once - -#include -#include -#include - -#include "conv2d_problems.h" -#include "../../common/cutlass_unit_test.h" -#include "../../gemm/device/testbed_utils.h" - -#include "cutlass/matrix_coord.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/layout/matrix.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_reduce.h" - -namespace test { -namespace conv { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Conv, - template class ActivationFunctor -> -struct TestbedConv2dWithAbsMax { - - using ElementAccumulator = typename Conv::ElementAccumulator; - using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute; - using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor; - using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax; - static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator; - - static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; - static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; - bool doScaleA; - bool doScaleB; - bool doScaleC; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_Aux; - cutlass::HostTensor tensor_D; - cutlass::HostTensor tensor_Vector; - cutlass::HostTensor tmp_D; - cutlass::HostTensor reference_D; - cutlass::HostTensor reference_Aux; - cutlass::HostTensor scale_A; - cutlass::HostTensor scale_B; - cutlass::HostTensor scale_C; - cutlass::HostTensor scale_D; - cutlass::HostTensor scale_Aux; - cutlass::HostTensor abs_max_Aux; - cutlass::HostTensor abs_max_D; - cutlass::HostTensor reference_abs_max_Aux; - cutlass::HostTensor reference_abs_max_D; - - // - // Methods - // - - TestbedConv2dWithAbsMax( - bool scaleA = true, - bool scaleB = true, - bool scaleC = true, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize scaling factors - template - bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { - cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); - return true; - } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) { - // - // Allocate the GEMM workspace - // - - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()}); - reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); - tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - cutlass::Coord<4> origin(0); - tensor_A.host_view().at(origin) = typename Conv::ElementA(1); - tensor_B.host_view().at(origin) = typename Conv::ElementB(1); - tensor_C.host_view().at(origin) = typename Conv::ElementC(1); - tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1); - - cutlass::reference::host::TensorFill(tensor_D.host_view()); - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - tensor_Vector.sync_device(); - - int scale_bits = 2; - if (doScaleA) { - scale_A.resize({1, 1, 1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits)); - scale_A.sync_device(); - } - - if (doScaleB) { - scale_B.resize({1, 1, 1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits)); - scale_B.sync_device(); - } - - if (doScaleC) { - scale_C.resize({1, 1, 1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits)); - scale_C.sync_device(); - } - - if (kScaleOutput) { - scale_D.resize({1, 1, 1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits)); - scale_D.sync_device(); - - abs_max_D.resize({1, 1, 1, 1}); - cutlass::reference::host::TensorFill(abs_max_D.host_view()); - abs_max_D.sync_device(); - - reference_abs_max_D.resize({1, 1, 1, 1}); - } - - if (kScaleAux) { - tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - cutlass::reference::host::TensorFill(tensor_Aux.host_view()); - tensor_Aux.sync_device(); - - scale_Aux.resize({1, 1, 1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits)); - scale_Aux.sync_device(); - - abs_max_Aux.resize({1, 1, 1, 1}); - cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); - abs_max_Aux.sync_device(); - - reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); - reference_abs_max_Aux.resize({1, 1, 1, 1}); - } - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::conv::Conv2dProblemSize const &problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); - - if (kScaleAux) { - tensor_Aux.sync_host(); - abs_max_Aux.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); - passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); - passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); - } - - if (kScaleOutput) { - abs_max_D.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); - passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); - } - - EXPECT_TRUE(passed) << " mismatched reference"; - - if (!passed) { - - std::ofstream file0("conv_testbed_with_amax_errors_reference.txt"); - std::ofstream file1("conv_testbed_with_amax_errors_computed.txt"); - - std::ofstream file("conv_testbed_with_amax_errors.txt"); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\nVector =\n" << tensor_Vector.host_view() - << "\nScaleA = " << scale_A.host_view() - << "\nScaleB = " << scale_B.host_view() - << "\nScaleC = " << scale_C.host_view() - << "\nScaleD = " << scale_D.host_view() - << "\nScaleAux = " << scale_Aux.host_view() - << std::endl; - - file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl; - file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl; - if (kScaleAux) { - file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl; - file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl; - file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl; - file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl; - } - if (kScaleOutput) { - file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl; - file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl; - } - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::conv::Conv2dProblemSize const &problem_size, - ElementCompute alpha, - ElementCompute beta) { - - cutlass::Coord<4> origin(0); - ElementCompute scaled_alpha = alpha; - if (doScaleA) { - scaled_alpha *= scale_A.host_view().at(origin); - } - if (doScaleB) { - scaled_alpha *= scale_B.host_view().at(origin); - } - - ElementCompute scaled_beta = beta; - if (doScaleC) { - scaled_beta *= scale_C.host_view().at(origin); - } - - // - // Verify - // - - cutlass::reference::host::Conv2d< - typename Conv::ElementA, typename Conv::LayoutA, - typename Conv::ElementB, typename Conv::LayoutB, - typename Conv::ElementC, typename Conv::LayoutC, - ElementCompute, ElementAccumulator, ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C.host_ref(), - tmp_D.host_ref(), - scaled_alpha, - scaled_beta - ); - - ElementCompute tmp_abs_max_Aux(0.); - ElementCompute tmp_abs_max_D(0.); - - cutlass::NumericConverter cvt_c_to_compute; - cutlass::NumericConverter cvt_accum_to_compute; - cutlass::NumericConverter cvt_compute_to_absmax; - cutlass::NumericConverter cvt_compute_to_d; - cutlass::NumericConverter cvt_compute_to_aux; - - cutlass::absolute_value_op abs; - cutlass::maximum_with_nan_propogation max; - ActivationFunctor act; - - ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); - - for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { - ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k})); - ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k})); - ElementCompute aux = intermediate + bias; - ElementCompute d = act(aux); - tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); - tmp_abs_max_D = max(abs(d), tmp_abs_max_D); - reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale); - - if (kScaleAux) { - reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); - } - } - } - } - } - if (kScaleAux) { - reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); - } - - if (kScaleOutput) { - reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); - } - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::conv::Conv2dProblemSize const &problem_size, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) - { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; - typename Conv::EpilogueOutputOp::Params epilogue_params{ - activation_params, - scale_A.device_data(), - scale_B.device_data(), - scale_C.device_data(), - scale_D.device_data(), - scale_Aux.device_data(), - abs_max_Aux.device_data(), - abs_max_D.device_data() - }; - - typename Conv::Arguments arguments{ - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), - tensor_Aux.device_ref(), - epilogue_params, - cutlass::conv::SplitKMode::kSerial, - tensor_Vector.device_data(), - 0 - }; - - Conv conv2d_op; - - cutlass::Status status = conv2d_op.can_implement(arguments); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - size_t workspace_size = Conv::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - status = conv2d_op.initialize(arguments, workspace.get()); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the GEMM - // - - status = conv2d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - cudaError_t cuda_error = cudaDeviceSynchronize(); - EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Failed" << std::endl; - } - - return passed; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ImplicitGemm, - template class ActivationFunctor = cutlass::epilogue::thread::Identity -> -bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { - const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(); - const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector(); - - // - // Testbed object - // - - TestbedConv2dWithAbsMax testbed(scaleA, scaleB, scaleC); - - // - // Get conv problem sizes to run conv operator - // - TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); - - // Vector of conv2d problem sizes to avoid duplicate runs - Conv2dProblemVector conv_tested_sizes; - - Conv2dProblemVector const *problem_vectors[] = { - &conv_test_sizes, // run user specified sizes - &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -#endif - }; - - bool passed = true; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv2dProblemVector const * problem_vector : problem_vectors) { - - // Prune all problems with channels that aren't divisible by the number of elements accessed per - // load for operands A and B. This is meant to align with the requirements of iterators used for - // fprop kernels. - ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits::value); - auto pruned_problem_vector = prune(*problem_vector, channel_spec); - - // Run conv testbed on default convolution sizes - for(auto conv_problem : pruned_problem_vector) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed &= testbed.run(conv_problem); - - if (!passed) { - return false; - } - - // test mode = convolution - passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution)); - - if (!passed) { - return false; - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h deleted file mode 100644 index f768f5b25f425910a49058599d3854352136caef..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ /dev/null @@ -1,734 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM for fused epilogue broadcast testbed - - Parallel split-k is not tested because we can just use regular conv kernel - when we need to use parallel-splitk. Broadcast can happen in the reduction - kernel. -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/reduction/device/reduce_split_k.h" -#include "cutlass/reduction/thread/reduction_operators.h" - -#include "conv2d_problems.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_compare.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/device/convolution.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/tensor_view_io.h" - -#include "../cache_testbed_output.h" - -namespace test { -namespace conv { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Conv2dWithBroadcastReferenceOp { - - using OutputOp = typename Conv2d::EpilogueOutputOp; - - using ElementCompute = typename OutputOp::ElementCompute; - using ElementZ = typename OutputOp::ElementZ; - using ElementT = typename OutputOp::ElementT; - - typename OutputOp::BinaryOp binary_op; - typename OutputOp::ElementwiseOp elementwise_op; - - Conv2dWithBroadcastReferenceOp() { } - - void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { - ElementCompute t_full = binary_op(conv2d, bias); - T = ElementT(t_full); - - ElementCompute z_full = elementwise_op(t_full); - Z = ElementZ(z_full); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Fused testbed -// -// Y = CONV(AB, C) -// -// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) -// -// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) -// - -template < - typename Conv2d, - typename ReferenceOp, - bool AddBroadcastFirst = false -> -class TestbedConv2dWithBroadcast { -public: - - using ElementA = typename Conv2d::ElementA; - using LayoutA = typename Conv2d::LayoutA; - using ElementB = typename Conv2d::ElementB; - using LayoutB = typename Conv2d::LayoutB; - using ElementC = typename Conv2d::ElementC; - using LayoutC = typename Conv2d::LayoutC; - using ElementAccumulator = typename Conv2d::ElementAccumulator; - using ElementCompute = typename Conv2d::ElementCompute; - using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; - using ElementZ = typename EpilogueOutputOp::ElementZ; - using ElementT = typename EpilogueOutputOp::ElementT; - using ElementVector = typename EpilogueOutputOp::ElementVector; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; - static const bool kAddBroadcastFirst = AddBroadcastFirst; - static const bool kStoreT = EpilogueOutputOp::kStoreT; - -public: - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_C_reference; - cutlass::HostTensor tensor_Z_computed; - cutlass::HostTensor tensor_Z_reference; - cutlass::HostTensor tensor_T_computed; - cutlass::HostTensor tensor_T_reference; - cutlass::HostTensor tensor_Y_reference; - cutlass::HostTensor tensor_Broadcast; // Input Broadcast - -public: - - TestbedConv2dWithBroadcast( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { - - } - - /// Helper to initialize a tensor view - template - void initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } - else if (bits == 16) { - if (cutlass::sizeof_bits::value <= 16) { - scope = 3; - } - else { - scope = 5; - } - } - else { - scope = 8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope, -scope, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - } - } - - void initialize( - cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { - - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Broadcast.resize({ - 1, - 1, - 1, - implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), - }); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); - - for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { - for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { - for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { - for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { - tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); - } - } - } - } - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_Broadcast.sync_device(); - tensor_C_reference.sync_device(); - tensor_Z_computed.sync_device(); - tensor_Z_reference.sync_device(); - tensor_T_computed.sync_device(); - tensor_T_reference.sync_device(); - tensor_Y_reference.sync_device(); - } - - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::conv::Conv2dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(1)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 //display conv2d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl - << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl - << std::endl; -#endif - - initialize(problem_size); - - // configure the operator - Conv2d conv2d_op; - typename Conv2d::Arguments conv2d_args( - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_Z_computed.device_ref(), - {alpha, beta}, - split_k_mode, - tensor_Broadcast.device_data(), - kStoreT ? tensor_T_computed.device_data() : nullptr, - 0, // This must be zero - implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() - ); - - // initialize the kernel - size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // run conv2d operator - status = conv2d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - bool passed = false; - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - - tensor_T_computed.sync_host(); - tensor_Z_computed.sync_host(); - - // - // Reference check - // - - // When kAddBroadcastFirst is true, add bias on the host - ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; - -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementAccumulator, - LayoutC, - ElementAccumulator, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C_reference.device_ref(), - tensor_Y_reference.device_ref(), - alpha, - beta_ref); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_Y_reference.sync_host(); - -#else - - cutlass::reference::host::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementAccumulator, - LayoutC, - ElementAccumulator, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C_reference.host_ref(), - tensor_Y_reference.host_ref(), - alpha, - beta_ref); - -#endif - ReferenceOp reference_op; - - // compute tensor Z and tensor T - for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { - for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { - for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { - - ElementZ z{}; - ElementT t{}; - - ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); - ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); - - - if (kAddBroadcastFirst) { - reference_op(z, t, accum + bias, - beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); - } else { - reference_op(z, t, accum, bias); - } - - tensor_Z_reference.at({n, p, q, k}) = z; - tensor_T_reference.at({n, p, q, k}) = t; - } - } - } - } - - if (kStoreT) { - passed = cutlass::reference::host::TensorEquals( - tensor_T_computed.host_view(), - tensor_T_reference.host_view()); - - EXPECT_TRUE(passed); - } - - passed = cutlass::reference::host::TensorEquals( - tensor_Z_computed.host_view(), - tensor_Z_reference.host_view()); - - EXPECT_TRUE(passed); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv2d_ImplicitGemm_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) - << "nhwc_" - << problem_size.N << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_krsc_" - << problem_size.K << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") - << Conv2d::ThreadblockShape::kM << "x" - << Conv2d::ThreadblockShape::kN << "x" - << Conv2d::ThreadblockShape::kK << "_" - << Conv2d::WarpShape::kM << "x" - << Conv2d::WarpShape::kN << "x" - << Conv2d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" - << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" - << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" - << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" - << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" - << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; - } - - return passed; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -template , - bool AddBroadcastFirst = false> -bool TestSpecificConv2dWithBroadcast( - const Conv2dProblemVector & problem_sizes) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv2dWithBroadcast testbed; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for(auto conv_problem : problem_sizes) { - - // - // Test - // - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - - return true; -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -// (conv_blacklist_sizes) -///////////////////////////////////////////////////////////////////////////////////////////////////////////// -template , - bool AddBroadcastFirst = false, - bool TestSplitK = true -> -bool TestAllConv2dWithBroadcast( - const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), - const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv2dWithBroadcast testbed; - - // - // Get conv problem sizes to run conv operator - // - TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); - - // Vector of conv2d problem sizes to avoid duplicate runs - Conv2dProblemVector conv_tested_sizes; - - Conv2dProblemVector const *problem_vectors[] = { - &conv_test_sizes, // run user specified sizes - &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -#endif - }; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv2dProblemVector const * problem_vector : problem_vectors) { - - // Run conv testbed on default convolution sizes - for(auto conv_problem : *problem_vector) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Procedurally disable certain cases - // - - // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kUnity)) { - if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } - -#if 0 // relax restrictions on analytic strided dgrad - // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } -#endif - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - } - - // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - - passed = testbed.run( - cutlass::conv::Conv2dProblemSize( - {1, 56, 56, 8}, // input size (NHWC) - {8, 1, 1, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1}), // dilation (dilation_h, dilation_w) - cutlass::conv::SplitKMode::kSerial, - cutlass::from_real(2.0), - cutlass::from_real(2.0)); - - if (!passed) { - return false; - } - - return passed; - } - - if (!TestSplitK) - return passed; - - // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for - // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters - // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep - // alpha and beta for local testing, but only runs one value for alpha and beta. - cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( - {1, 17, 11, 288}, // input size (NHWC) - {160, 3, 3, 288}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - ); - - cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial - }; - - int split_k_slices[] = { - 1, 2, 3, 4, 201 - }; - - double problem_alpha[] = { - 2.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (auto split_k_mode : split_k_modes) { - for (auto split_k_slice : split_k_slices) { - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - passed = testbed.run( - conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - return false; - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h deleted file mode 100644 index a8ec16ca5de369470f5dc50bb6f8b5e2da3da10d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h +++ /dev/null @@ -1,643 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/reduction/device/tensor_reduce.h" -#include "cutlass/reduction/device/reduce_split_k.h" -#include "cutlass/reduction/thread/reduction_operators.h" - -#include "conv2d_problems.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_compare.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/device/convolution.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/tensor_view_io.h" - -#include "../cache_testbed_output.h" - -namespace test { -namespace conv { -namespace device { - -template -class TestbedConv2dWithReduction { -public: - - using ElementA = typename Conv2d::ElementA; - using LayoutA = typename Conv2d::LayoutA; - using ElementB = typename Conv2d::ElementB; - using LayoutB = typename Conv2d::LayoutB; - using ElementC = typename Conv2d::ElementC; - using LayoutC = typename Conv2d::LayoutC; - using ElementAccumulator = typename Conv2d::ElementAccumulator; - using ElementCompute = typename Conv2d::ElementCompute; - using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; - using ElementT = typename EpilogueOutputOp::ElementTensor; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; - -public: - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - - cutlass::HostTensor tensor_Reduction; - cutlass::HostTensor tensor_Tensor; - cutlass::HostTensor tensor_Final_Reduction; - - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - -public: - - TestbedConv2dWithReduction( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { - - } - - /// Helper to initialize a tensor view - template - void initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - int scope = 2; - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope, -scope, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - } - } - - void initialize( - cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { - - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - - tensor_Reduction.resize({ - 1, - 1, - (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, - (problem_size.K) - }); - - tensor_Final_Reduction.resize({ - 1, - 1, - 1, - (problem_size.K) - }); - - tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); - - tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_D_reference.sync_device(); - } - - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::conv::Conv2dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 //display conv2d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl - << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl - << std::endl; -#endif - - initialize(problem_size); - - // configure the operator - Conv2d conv2d_op; - - typename Conv2d::Arguments conv2d_args( - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_computed.device_ref(), - {alpha, beta}, - split_k_mode, - tensor_Reduction.device_data(), - tensor_Tensor.device_data(), - static_cast(tensor_Reduction.stride()[0]), - static_cast(tensor_Tensor.stride()[0]) - ); - - // find workspace requirement for parallel split-k reduction - size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // conv2d operation with parallel split-k-mode - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // conv2d output is written to workspace in global memory - conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); - // accumulate mma for each cta in k-dimension (1.0 * A * B) - conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; - // update conv2d operator arguments - status = conv2d_op.update(conv2d_args, workspace.get()); - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run conv2d operator - status = conv2d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - bool passed = false; - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - - // Final reduction over the partial reduction tensor - using Functor = cutlass::plus; - using TensorReduction = cutlass::reduction::device::TensorReduction< - ElementAccumulator, - ElementAccumulator, - LayoutC, - Functor, - 8, - ElementAccumulator - >; - - TensorReduction reduction(tensor_Reduction.extent(), 2); - - cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); - - status = reduction.reduce( - tensor_Final_Reduction.device_ref(), - tensor_Reduction.device_ref(), - reduction_device_workspace.get(), - ElementAccumulator()); - - EXPECT_EQ(status, cutlass::Status::kSuccess); - EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); - - // - // Reference check - // - - tensor_D_computed.sync_host(); - -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_reference.device_ref(), - alpha, - beta); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_D_reference.sync_host(); - -#else - - cutlass::reference::host::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C.host_ref(), - tensor_D_reference.host_ref(), - alpha, - beta); - -#endif - - passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view()); - - EXPECT_TRUE(passed); - - // - // Reference check on reduction results - // - - tensor_Reduction.sync_host(); - tensor_Final_Reduction.sync_host(); - - // compute backwards for reduction results - cutlass::HostTensor reference_Reduction; - reference_Reduction.resize({ - 1, - 1, - 1, - (problem_size.K) - }); - - for (int k = 0; k < problem_size.K; ++k) { - ElementAccumulator reduced_value = ElementAccumulator(); - for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - reduced_value += tensor_D_reference.at({n, p, q, k}); - } - } - } - reference_Reduction.at({0, 0, 0, k}) = reduced_value; - } - - passed = cutlass::reference::host::TensorEquals( - tensor_Final_Reduction.host_view(), - reference_Reduction.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv2d_ImplicitGemm_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) - << "nhwc_" - << problem_size.N << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_krsc_" - << problem_size.K << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") - << Conv2d::ThreadblockShape::kM << "x" - << Conv2d::ThreadblockShape::kN << "x" - << Conv2d::ThreadblockShape::kK << "_" - << Conv2d::WarpShape::kM << "x" - << Conv2d::WarpShape::kN << "x" - << Conv2d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" - << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" - << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" - << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -// (conv_blacklist_sizes) -///////////////////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestAllConv2dWithReduction( - const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), - const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv2dWithReduction testbed; - - // - // Get conv problem sizes to run conv operator - // - TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); - - // Vector of conv2d problem sizes to avoid duplicate runs - Conv2dProblemVector conv_tested_sizes; - - Conv2dProblemVector const *problem_vectors[] = { - &conv_test_sizes, // run user specified sizes - &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -#endif - }; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv2dProblemVector const * problem_vector : problem_vectors) { - - // Run conv testbed on default convolution sizes - for(auto conv_problem : *problem_vector) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Procedurally disable certain cases - // - - // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kUnity)) { - if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } - -#if 0 // relax restrictions on analytic strided dgrad - // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } -#endif - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - } - - // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - - passed = testbed.run( - cutlass::conv::Conv2dProblemSize( - {1, 56, 56, 8}, // input size (NHWC) - {8, 1, 1, 8}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1}), // dilation (dilation_h, dilation_w) - cutlass::conv::SplitKMode::kSerial, - cutlass::from_real(2.0), - cutlass::from_real(2.0)); - - if (!passed) { - return false; - } - - return passed; - } - - // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for - // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters - // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep - // alpha and beta for local testing, but only runs one value for alpha and beta. - cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( - {1, 17, 11, 288}, // input size (NHWC) - {160, 3, 3, 288}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - ); - - // Parallel SplitK is not tested. - cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial, - }; - - int split_k_slices[] = { - 1, 2, 3, 4, 201 - }; - - double problem_alpha[] = { - 2.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (auto split_k_mode : split_k_modes) { - for (auto split_k_slice : split_k_slices) { - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - passed = testbed.run( - conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - return false; - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h deleted file mode 100644 index fae7d6194fb671594221a90faea7cac1e5fbeb9f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h +++ /dev/null @@ -1,293 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed sizes for Conv2d problem -*/ -#pragma once - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/cutlass.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/numeric_types.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" - -namespace test { -namespace conv { -namespace device { - -using Conv3dProblemVector = std::vector; - -//////////////////////////////////////////////////////////////////////////// -/// Structure TestbedConv3dProblemSizes initializes and holds conv default and -/// important network sizes -//////////////////////////////////////////////////////////////////////////// -struct TestbedConv3dProblemSizes { - - // - // Data members - // - int minimum_channel_size; - Conv3dProblemVector conv3d_default_sizes; - Conv3dProblemVector conv3d_vnet_medical_sizes; - - // - // Methods - // - /// Default ctor - TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { - - initialize_conv3d_default_sizes(); - initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); - - filter_all(); - } - - /// Eliminates some illegal cases - void filter_all() { - - Conv3dProblemVector *problems_vectors[] = { - &conv3d_default_sizes, - &conv3d_vnet_medical_sizes - }; - - for (Conv3dProblemVector *problems : problems_vectors) { - Conv3dProblemVector filtered; - - for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { - if (!(problem.C % minimum_channel_size)) { - filtered.push_back(problem); - } - } - - *problems = filtered; - } - } - - // Add a few standard convolution problem sizes - void initialize_conv3d_default_sizes() { - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) - {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) - {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) - {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) - CUTLASS_STL_NAMESPACE::make_tuple( - cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) - ), - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) - {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) - {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) - CUTLASS_STL_NAMESPACE::make_tuple( - cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) - ), - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) - {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 1, 15, 19, 160}, // input size (NDHWC) - {224, 1, 3, 6, 160}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) - {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) - {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( - {1, 11, 15, 19, 64}, // input size (NDHWC) - {32, 4, 3, 6, 64}, // filter size (KTRSC) - cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - } - - // Add vnet layers to unit testing sizes - void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 32, 32, 32, 16}, // input size (NDHWC) - {32, 2, 2, 2, 16}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 16, 16, 16, 32}, // input size (NDHWC) - {32, 3, 3, 3, 32}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 16, 16, 16, 32}, // input size (NDHWC) - {64, 2, 2, 2, 32}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 8, 8, 8, 64}, // input size (NDHWC) - {64, 3, 3, 3, 64}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 8, 8, 8, 64}, // input size (NDHWC) - {128, 2, 2, 2, 64}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 4, 4, 4, 128}, // input size (NDHWC) - {128, 3, 3, 3, 128}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 8, 8, 8, 128}, // input size (NDHWC) - {128, 3, 3, 3, 128}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 16, 16, 16, 64}, // input size (NDHWC) - {64, 3, 3, 3, 64}, // filter size (KTRSC) - cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 32, 32, 32, 16}, // input size (NDHWC) - {64, 2, 2, 2, 16}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - - conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( - {batch_size, 16, 16, 16, 32}, // input size (NDHWC) - {128, 2, 2, 2, 32}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - )); - - } - -}; - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h deleted file mode 100644 index 029f5effb9103bebd4ee61767795d3883541d986..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h +++ /dev/null @@ -1,716 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - - -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/reduction/device/reduce_split_k.h" -#include "cutlass/reduction/thread/reduction_operators.h" - -#include "cutlass/util/reference/host/tensor_fill.h" - -#include "cutlass/util/reference/host/convolution.h" - -#include "cutlass/util/reference/host/tensor_compare.h" - -#include "cutlass/util/reference/device/convolution.h" -#include "cutlass/util/reference/device/tensor_compare.h" - -#include "conv3d_problems.h" -#include "cutlass/core_io.h" - -#include "../cache_testbed_output.h" - -namespace test { -namespace conv { -namespace device { - -template -class TestbedConv3d { -public: - - using ElementA = typename Conv3d::ElementA; - using LayoutA = typename Conv3d::LayoutA; - using ElementB = typename Conv3d::ElementB; - using LayoutB = typename Conv3d::LayoutB; - using ElementC = typename Conv3d::ElementC; - using LayoutC = typename Conv3d::LayoutC; - using ElementAccumulator = typename Conv3d::ElementAccumulator; - using ElementCompute = typename Conv3d::ElementCompute; - using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; - - /// Reduction kernel - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, - typename EpilogueOutputOp::ElementAccumulator, - EpilogueOutputOp::kCount - >; - - using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< - cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, - EpilogueOutputOp, - ReductionOp - >; - - using ReductionDevice = cutlass::reduction::device::ReduceSplitK; - using ReductionStrideIndex = typename ReductionDevice::StrideIndex; - -public: - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - -public: - - TestbedConv3d( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { - - } - - /// Helper to initialize a tensor view - template - void initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } - else if (bits == 16) { - scope = 4; - } - else { - scope = 8; - } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope, -scope, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - } - } - - void initialize( - cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) { - - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_D_reference.sync_device(); - } - - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - - /// Executes one test - bool run( - cutlass::conv::Conv3dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute()) { - - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 //display conv2d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl - << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl - << std::endl; -#endif - - initialize(problem_size); - - // configure the operator - Conv3d conv3d_op; - - typename Conv3d::Arguments conv3d_args( - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_computed.device_ref(), - {alpha, beta}, - split_k_mode - ); - - cutlass::Status status = conv3d_op.can_implement(conv3d_args); - if (status != cutlass::Status::kSuccess) { - std::cerr << "can_implement failed for the given problem_size: \n"; - return false; - } - - // find workspace requirement for parallel split-k reduction - size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - status = conv3d_op.initialize(conv3d_args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // conv3d operation with parallel split-k-mode - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // conv3d output is written to workspace in global memory - conv3d_args.ref_D.reset(reinterpret_cast(workspace.get())); - // accumulate mma for each cta in k-dimension (1.0 * A * B) - conv3d_args.output_op = {1.0, 0.0}; - // update conv3d operator arguments - status = conv3d_op.update(conv3d_args, workspace.get()); - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run conv3d operator - status = conv3d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // configure parallel reduction operator - ReductionDevice reduction_op; - - typename ReductionDevice::Arguments reduction_args( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), - problem_size.split_k_slices, - cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), - { - reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) - }, - { - tensor_D_computed.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) - }, - { - tensor_C.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) - }, - // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C - {alpha, beta} - ); - - status = reduction_op.initialize(reduction_args, nullptr); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - // run prallel reduction kernel - status = reduction_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - } - bool passed = false; - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - - tensor_D_computed.sync_host(); - - // - // Reference check - support caching results - // - - CachedTestKey cached_test_key = CreateCachedConv3dTestKey< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute - >( - kConvolutionalOperator, - problem_size, - alpha, - beta, - tensor_A.host_view(), - tensor_B.host_view(), - tensor_C.host_view() - ); - - // - // Look for the cached key - // - - bool cached_result_loaded = false; - CachedTestResult cached_test_result; - - std::string conv3d_result_cache_name = - std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - CachedTestResultListing cached_results(conv3d_result_cache_name); - - auto cached = cached_results.find(cached_test_key); - - cached_result_loaded = cached.first; - if (cached_result_loaded) { - cached_test_result = cached.second; - } - } - - if (!cached_result_loaded) { - -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv3d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - ElementCompute - >( - kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_reference.device_ref(), - alpha, - beta - ); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_D_reference.sync_host(); - -#else - cutlass::reference::host::Conv3d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementAccumulator, - ElementCompute - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C.host_ref(), - tensor_D_reference.host_ref(), - alpha, - beta - ); -#endif - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - - CachedTestResultListing cached_results(conv3d_result_cache_name); - - cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv3d_result_cache_name); - } - } // if (!cached_result_loaded) - - uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - passed = (tensor_D_hash == cached_test_result.D); - - EXPECT_EQ(tensor_D_hash, cached_test_result.D) - << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; - } - else { - - passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view()); - } - - EXPECT_TRUE(passed); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv3d_ImplicitGemm_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) - << "ndhwc_" - << problem_size.N << "x" - << problem_size.D << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_ktrsc_" - << problem_size.K << "x" - << problem_size.T << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_d << "x" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_d << "x" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_d << "x" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") - << Conv3d::ThreadblockShape::kM << "x" - << Conv3d::ThreadblockShape::kN << "x" - << Conv3d::ThreadblockShape::kK << "_" - << Conv3d::WarpShape::kM << "x" - << Conv3d::WarpShape::kN << "x" - << Conv3d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n"; - - - results << "\nD reference (hash: " << cached_test_result.D << ")\n"; - - if (!cached_result_loaded) { - results - << tensor_D_reference.host_view() << "\n"; - } - - results - << "\nD computed (hash: " << tensor_D_hash << ")\n" - << tensor_D_computed.host_view() << "\n"; - - } - - return passed; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -// (conv_blacklist_sizes) -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllConv3d( - const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), - const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { - - bool passed = true; - - // - // Testbed object - // - - //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); - TestbedConv3d testbed; - - // - // Get conv problem sizes to run conv operator - // - TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); - - // Vector of conv3d problem sizes to avoid duplicate runs - Conv3dProblemVector conv_tested_sizes; - - Conv3dProblemVector const *problem_vectors[] = { - &conv3d_problems.conv3d_default_sizes, - &conv3d_problems.conv3d_vnet_medical_sizes, - &conv_test_sizes - }; - - // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv3dProblemVector const * problem_vector : problem_vectors) { - - // Run conv testbed on default convolution sizes - for(auto conv_problem : *problem_vector) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Procedurally disable certain cases - // - - // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kUnity) || - (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == - cutlass::conv::StrideSupport::kUnity))) { - if (!((conv_problem.stride_d == 1) && - (conv_problem.stride_h == 1) && - (conv_problem.stride_w == 1)) - ) { - continue; - } - } - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - } - - // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for - // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters - // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep - // alpha and beta for local testing, but only runs one value for alpha and beta. - cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( - {1, 8, 8, 8, 32}, // input size (NDHWC) - {32, 3, 3, 3, 32}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - ); - - cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial, - cutlass::conv::SplitKMode::kParallel - }; - - int split_k_slices[] = { - 1, 2, 3, 4, 201 - }; - - double problem_alpha[] = { - 2.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (auto split_k_mode : split_k_modes) { - for (auto split_k_slice : split_k_slices) { - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - passed = testbed.run( - conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - return false; - } - } - } - } - } - - return passed; -} - -template -bool TestSpecificConv3d( - const Conv3dProblemVector & problem_sizes) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv3d testbed; - - // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for(auto conv_problem : problem_sizes) { - - // - // Test - // - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - - return true; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h deleted file mode 100644 index f8ba785c9d0ecbdd518711714558c9e166c0209a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h +++ /dev/null @@ -1,732 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM for fused epilogue broadcast testbed - - Parallel split-k is not tested because we can just use regular conv kernel - when we need to use parallel-splitk. Broadcast can happen in the reduction - kernel. -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/reduction/device/reduce_split_k.h" -#include "cutlass/reduction/thread/reduction_operators.h" - -#include "conv3d_problems.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_compare.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/device/convolution.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/tensor_view_io.h" - -#include "../cache_testbed_output.h" - -namespace test { -namespace conv { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Conv3dWithBroadcastReferenceOp { - - using OutputOp = typename Conv3d::EpilogueOutputOp; - - using ElementCompute = typename OutputOp::ElementCompute; - using ElementZ = typename OutputOp::ElementZ; - using ElementT = typename OutputOp::ElementT; - - typename OutputOp::BinaryOp binary_op; - typename OutputOp::ElementwiseOp elementwise_op; - - Conv3dWithBroadcastReferenceOp() { } - - void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) { - ElementCompute t_full = binary_op(conv3d, bias); - T = ElementT(t_full); - - ElementCompute z_full = elementwise_op(t_full); - Z = ElementZ(z_full); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Fused testbed -// -// Y = CONV(AB, C) -// -// T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k]) -// -// Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k]) -// - -template < - typename Conv3d, - typename ReferenceOp, - bool AddBroadcastFirst = false -> -class TestbedConv3dWithBroadcast { -public: - - using ElementA = typename Conv3d::ElementA; - using LayoutA = typename Conv3d::LayoutA; - using ElementB = typename Conv3d::ElementB; - using LayoutB = typename Conv3d::LayoutB; - using ElementC = typename Conv3d::ElementC; - using LayoutC = typename Conv3d::LayoutC; - using ElementAccumulator = typename Conv3d::ElementAccumulator; - using ElementCompute = typename Conv3d::ElementCompute; - using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; - using ElementZ = typename EpilogueOutputOp::ElementZ; - using ElementT = typename EpilogueOutputOp::ElementT; - using ElementVector = typename EpilogueOutputOp::ElementVector; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; - static const bool kAddBroadcastFirst = AddBroadcastFirst; - static const bool kStoreT = EpilogueOutputOp::kStoreT; - -public: - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_C_reference; - cutlass::HostTensor tensor_Z_computed; - cutlass::HostTensor tensor_Z_reference; - cutlass::HostTensor tensor_T_computed; - cutlass::HostTensor tensor_T_reference; - cutlass::HostTensor tensor_Y_reference; - cutlass::HostTensor tensor_Broadcast; // Input Broadcast - -public: - - TestbedConv3dWithBroadcast( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { - - } - - /// Helper to initialize a tensor view - template - void initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } - else if (bits == 16) { - if (cutlass::sizeof_bits::value <= 16) { - scope = 3; - } - else { - scope = 5; - } - } - else { - scope = 8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope, -scope, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - } - } - - void initialize( - cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) { - - // to make the layout of tensors a little bit bigger than the problem size - cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64); - - cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size); - cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size); - cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size); - - if (non_packed_test) { - tensor_A_extent += stride_increment; - tensor_C_extent += stride_increment; - } - - tensor_A.resize(tensor_A_extent); - tensor_B.resize(tensor_B_extent); - tensor_C.resize(tensor_C_extent); - tensor_C_reference.resize(tensor_C_extent); - tensor_Z_computed.resize(tensor_C_extent); - tensor_Z_reference.resize(tensor_C_extent); - tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Y_reference.resize(tensor_C_extent); - tensor_Broadcast.resize({ - 1, - 1, - 1, - 1, - implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), - }); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); - for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { - for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { - for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { - for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { - for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { - tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k})); - } - } - } - } - } - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_Broadcast.sync_device(); - tensor_C_reference.sync_device(); - tensor_Z_computed.sync_device(); - tensor_Z_reference.sync_device(); - tensor_T_computed.sync_device(); - tensor_T_reference.sync_device(); - tensor_Y_reference.sync_device(); - } - - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::conv::Conv3dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - bool non_packed_test = false, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(1)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 //display conv3d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl - << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl - << std::endl; -#endif - - initialize(problem_size, non_packed_test); - - // configure the operator - Conv3d conv3d_op; - typename Conv3d::Arguments conv3d_args( - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_Z_computed.device_ref(), - {alpha, beta}, - split_k_mode, - tensor_Broadcast.device_data(), - kStoreT ? tensor_T_computed.device_data() : nullptr, - 0, // This must be zero - implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() - ); - - // initialize the kernel - size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // run conv3d operator - status = conv3d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - bool passed = false; - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - - tensor_T_computed.sync_host(); - tensor_Z_computed.sync_host(); - - // - // Reference check - // - - // When kAddBroadcastFirst is true, add bias on the host - ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; - -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv3d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementAccumulator, - LayoutC, - ElementAccumulator, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C_reference.device_ref(), - tensor_Y_reference.device_ref(), - alpha, - beta_ref); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_Y_reference.sync_host(); - -#else - - cutlass::reference::host::Conv3d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementAccumulator, - LayoutC, - ElementAccumulator, - ElementAccumulator - >( - kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C_reference.host_ref(), - tensor_Y_reference.host_ref(), - alpha, - beta_ref); - -#endif - ReferenceOp reference_op; - - // compute tensor Z and tensor T - for (int n = 0; n < problem_size.N; ++n) { - for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { - for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { - for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { - for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { - - ElementZ z{}; - ElementT t{}; - - ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k}); - ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k})); - - - if (kAddBroadcastFirst) { - reference_op(z, t, accum + bias, - beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k}))); - } else { - reference_op(z, t, accum, bias); - } - - tensor_Z_reference.at({n, o, p, q, k}) = z; - tensor_T_reference.at({n, o, p, q, k}) = t; - } - } - } - } - } - - if (kStoreT) { - passed = cutlass::reference::host::TensorEquals( - tensor_T_computed.host_view(), - tensor_T_reference.host_view()); - - EXPECT_TRUE(passed); - } - - passed = cutlass::reference::host::TensorEquals( - tensor_Z_computed.host_view(), - tensor_Z_reference.host_view()); - - EXPECT_TRUE(passed); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv3d_ImplicitGemm_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) - << "nnhwc_" - << problem_size.N << "x" - << problem_size.D << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_krsc_" - << problem_size.K << "x" - << problem_size.T << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_d << "x" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_d << "x" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_d << "x" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") - << (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_") - << Conv3d::ThreadblockShape::kM << "x" - << Conv3d::ThreadblockShape::kN << "x" - << Conv3d::ThreadblockShape::kK << "_" - << Conv3d::WarpShape::kM << "x" - << Conv3d::WarpShape::kN << "x" - << Conv3d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" - << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" - << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" - << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" - << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" - << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes -// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -// (conv_blacklist_sizes) -///////////////////////////////////////////////////////////////////////////////////////////////////////////// -template , - bool AddBroadcastFirst = false, - bool TestSplitK = true -> -bool TestAllConv3dWithBroadcast( - const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(), - const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(), - bool non_packed_test = false) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv3dWithBroadcast testbed; - - // - // Get conv problem sizes to run conv operator - // - TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); - - // Vector of conv3d problem sizes to avoid duplicate runs - Conv3dProblemVector conv_tested_sizes; - - Conv3dProblemVector const *problem_vectors[] = { - &conv3d_problems.conv3d_default_sizes, - &conv3d_problems.conv3d_vnet_medical_sizes, - &conv_test_sizes - }; - - // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv3dProblemVector const * problem_vector : problem_vectors) { - - // Run conv testbed on default convolution sizes - for(auto conv_problem : *problem_vector) { - - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { - continue; - } - - // - // Procedurally disable certain cases - // - - // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kUnity)) { - if (!((conv_problem.stride_d == 1) && - (conv_problem.stride_h == 1) && - (conv_problem.stride_w == 1)) - ) { - continue; - } - } - -#if 0 // relax restrictions on analytic strided dgrad - // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || - ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && - (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == - cutlass::conv::StrideSupport::kStrided)) { - if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { - continue; - } - } -#endif - - // - // Test - // - // push back tested problem size to avoid re-running duplicates - conv_tested_sizes.push_back(conv_problem); - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial, non_packed_test); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial, non_packed_test); - - if (!passed) { - return false; - } - } - } - - if (!TestSplitK) - return passed; - - // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for - // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters - // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep - // alpha and beta for local testing, but only runs one value for alpha and beta. - cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( - {1, 8, 8, 8, 32}, // input size (NDHWC) - {32, 3, 3, 3, 32}, // filter size (KTRSC) - cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) - cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) - cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) - ); - - cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial - }; - - int split_k_slices[] = { - 1, 2, 3, 4, 201 - }; - - double problem_alpha[] = { - 2.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (auto split_k_mode : split_k_modes) { - for (auto split_k_slice : split_k_slices) { - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - passed = testbed.run( - conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - false,/*non_packed_test*/ - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - return false; - } - } - } - } - } - - return passed; -} - -template , - bool AddBroadcastFirst = false> -bool TestSpecificConv3dWithBroadcast( - const Conv3dProblemVector & problem_sizes, - bool non_packed_test = false) { - - bool passed = true; - - // - // Testbed object - // - - TestbedConv3dWithBroadcast testbed; - - // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for(auto conv_problem : problem_sizes) { - - // - // Test - // - - // test mode = xcross, non_packed_test = false - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial, non_packed_test); - - if (!passed) { - return false; - } - - // test mode = convolution, non_packed_test = false - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial, non_packed_test); - - if (!passed) { - return false; - } - } - - return true; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h deleted file mode 100644 index cef5f981c595dfbbb95658fb757865b219538192..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h +++ /dev/null @@ -1,473 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Depthwise Direct Conv testbed -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" -#include "../cache_testbed_output.h" -#include "conv2d_problems.h" -#include "cutlass/conv/device/direct_convolution.h" - -#include "cutlass/core_io.h" -#include "cutlass/cutlass.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/device/convolution.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/tensor_view_io.h" - -namespace test { -namespace conv { -namespace device { - -template -class TestbedDepthwiseDirectConv2d { - public: - - using ElementA = typename Conv2d::ElementA; - using LayoutA = typename Conv2d::LayoutA; - using ElementB = typename Conv2d::ElementB; - using LayoutB = typename Conv2d::LayoutB; - using ElementC = typename Conv2d::ElementC; - using LayoutC = typename Conv2d::LayoutC; - using ElementAccumulator = typename Conv2d::ElementAccumulator; - using ElementCompute = typename Conv2d::ElementCompute; - using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; - - static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; - - public: - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_reordered_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - int tested_problem_count; - - public: - TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080) - : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} - - /// Helper to initialize a tensor view - template - void initialize_tensor(cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - if (dist_kind == cutlass::Distribution::Uniform) { - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } else if (bits == 16) { - if (cutlass::sizeof_bits::value <= 16) { - scope = 3; - } else { - scope = 5; - } - } else { - scope = 8; - } - cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); - } else if (dist_kind == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(view); - - } else if (dist_kind == cutlass::Distribution::Gaussian) { - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } else { - } - } - - void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - - initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); - initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_reordered_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_D_reference.sync_device(); - } - - bool sufficient(int smem_size) const { - // - // Determine SMEM requirements and waive if not satisfied - // - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < static_cast(smem_size)) { - return false; - } - - return true; - } - - /// Executes one test - bool run(cutlass::conv::Conv2dProblemSize const &problem_size, - cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, - ElementCompute alpha = ElementCompute(1.5), - ElementCompute beta = ElementCompute(1)) { - // increment tested problem count run by the testbed - tested_problem_count++; - -#if 0 // display conv2d problem size for debugging - std::cout << problem_size << std::endl - << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl - << "split_k_mode: " - << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") - << std::endl - << std::endl; -#endif - - initialize(problem_size); - - // configure the operator - Conv2d conv2d_op; - - typename Conv2d::Arguments conv2d_args(problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_computed.device_ref(), - {alpha, beta}, - tensor_reordered_B.device_ref(), - split_k_mode); - - // find workspace requirement for parallel split-k reduction - size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = conv2d_op.can_implement(problem_size); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - status = conv2d_op.initialize(conv2d_args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - if (!sufficient(conv2d_op.get_smem_size())) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - // run conv2d operator - status = conv2d_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run." << std::endl; - return false; - } - - bool passed = false; - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); - - tensor_D_computed.sync_host(); - - // - // Reference check - support caching results - // - - CachedTestKey cached_test_key = - CreateCachedConv2dTestKey(kConvolutionalOperator, - problem_size, - alpha, - beta, - tensor_A.host_view(), - tensor_B.host_view(), - tensor_C.host_view()); - - // - // Look for the cached key - // - - bool cached_result_loaded = false; - CachedTestResult cached_test_result; - - std::string conv2d_result_cache_name = - std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - auto cached = cached_results.find(cached_test_key); - - cached_result_loaded = cached.first; - if (cached_result_loaded) { - cached_test_result = cached.second; - } - } - - if (!cached_result_loaded) { -#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - - cutlass::reference::device::Conv2d(kConvolutionalOperator, - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_reference.device_ref(), - alpha, - beta); - - // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_D_reference.sync_host(); - -#else - - cutlass::reference::host::Conv2d(kConvolutionalOperator, - problem_size, - tensor_A.host_ref(), - tensor_B.host_ref(), - tensor_C.host_ref(), - tensor_D_reference.host_ref(), - alpha, - beta); - -#endif - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv2d_result_cache_name); - } - } // if (!cached_result_loaded) - - uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - passed = (tensor_D_hash == cached_test_result.D); - - EXPECT_EQ(tensor_D_hash, cached_test_result.D) - << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; - } - else { - - passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view()); - } - - EXPECT_TRUE(passed); - - std::stringstream ss_problem_size_text; - ss_problem_size_text << "nhwc_" - << problem_size.N << "x" - << problem_size.H << "x" - << problem_size.W << "x" - << problem_size.C - << "_krsc_" - << problem_size.K << "x" - << problem_size.R << "x" - << problem_size.S << "x" - << problem_size.C - << "_padding_" - << problem_size.pad_h << "x" - << problem_size.pad_w - << "_stride_" - << problem_size.stride_h << "x" - << problem_size.stride_w - << "_dilation_" - << problem_size.dilation_h << "x" - << problem_size.dilation_w << "_" - << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); - - if (!passed) { - std::stringstream fname; - - fname << "error_Conv2d_DirectConv_device_" - << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") - << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) - << ss_problem_size_text.str() - << Conv2d::ThreadblockShape::kM << "x" - << Conv2d::ThreadblockShape::kN << "x" - << Conv2d::ThreadblockShape::kK << "_" - << Conv2d::WarpShape::kM << "x" - << Conv2d::WarpShape::kN << "x" - << Conv2d::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n"; - - results << "\nD reference (hash: " << cached_test_result.D << ")\n"; - - if (!cached_result_loaded) { - results - << tensor_D_reference.host_view() << "\n"; - } - - results - << "\nD computed (hash: " << tensor_D_hash << ")\n" - << tensor_D_computed.host_view() << "\n"; - - } - - return passed; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { - bool passed = true; - - // - // Testbed object - // - TestbedDepthwiseDirectConv2d testbed; - - // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (auto conv_problem : problem_sizes) { - // - // Test - // - - // test mode = xcross - passed = testbed.run( - conv_problem, - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - - // test mode = convolution - passed = testbed.run( - conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - - if (!passed) { - return false; - } - } - - return true; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace conv -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp deleted file mode 100644 index 54c11281e14b813b249d7f9710542843b37bcc68..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp +++ /dev/null @@ -1,1385 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem -*/ -#pragma once - -#include "cutlass/conv/convnd_problem_shape.hpp" -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test::conv::device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -std::vector> -inline -get_conv_problem_vector(); - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Fprop -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Specialization for 1D fprop problems -template<> -std::vector> inline -get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nwc - {64, 1, 64}, // ksc - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // non-packed input strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nwc - {800, 80, 1}, // stride (nwc) - {64, 1, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // non-packed output strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nwc - {512, 64, 1}, // stride (nwc) - {64, 1, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {800, 80, 1}, // stride (nqk) - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, - {16,1, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // N = 2 and K = 128 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 64}, - {96, 1, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // N = 7 and K = 256 for a even larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {7, 8, 64}, - {256, 1, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // 3 filter, no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 64}, - {256, 3, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // 3 filter, symmetric padding with c % cta_k !=0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 32}, - {256, 3, 32}, - {1}, - {1}, - {1}, - {1}, - 1 - }); - // 4 filter, asymmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 64}, - {256, 4, 64}, - {0}, - {1}, - {1}, - {1}, - 1 - }); - // 3 filter, asymmetric padding and tstride of 2 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 64}, - {256, 3, 64}, - {0}, - {1}, - {2}, - {1}, - 1 - }); - // 3 filter, asymmetric padding and dilation of 2 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 64}, - {256, 3, 64}, - {0}, - {1}, - {1}, - {2}, - 1 - }); - return problem_shapes; -} - -// Specialization for 2D fprop problems -template<> -std::vector> inline -get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // nhwc - {64, 1, 1, 64}, // krsc - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // non-packed input strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // nhwc - {8000, 800, 80, 1}, // stride (nhwc) - {64, 1, 1, 64}, // krsc - {64, 64, 64, 1}, // stride (krsc) - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // non-packed output strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // nhwc - {4096, 512, 64, 1}, // stride (nhwc) - {64, 1, 1, 64}, // krsc - {64, 64, 64, 1}, // stride (krsc) - {8000, 800, 80, 1}, // stride (npqk) - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, - {16, 1, 1, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // N = 2 and K = 128 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 64}, - {96, 1, 1, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // N = 7 and K = 256 for a even larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {7, 8, 8, 64}, - {256, 1, 1, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 3x3 filter, no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 64}, - {256, 3, 3, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 3x3 filter, symmetric padding with c % cta_k !=0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 32}, - {256, 3, 3, 32}, - {1, 1}, - {1, 1}, - {1, 1}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,2/1,2 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 64}, - {256, 2, 5, 64}, - {1, 1}, - {2, 2}, - {1, 1}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 7, 7, 64}, - {256, 2, 5, 64}, - {1, 1}, - {0, 0}, - {2, 3}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 16, 64}, - {256, 2, 5, 64}, - {1, 1}, - {0, 0}, - {1, 1}, - {2, 3}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 15, 64}, - {256, 2, 5, 64}, - {1, 1}, - {0, 0}, - {2, 3}, - {2, 3}, - 1 - }); - return problem_shapes; -} - -// Specialization for 3D fprop problems -template<> -std::vector> inline -get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 8, 8, 64}, // ndhwc - {64, 1, 1, 1, 64}, // ktrsc - {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) - {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 1 // group - }); - // non-packed input output strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 8, 8, 64}, // ndhwc - {8000, 8000, 800, 80, 1}, // stride (ndhwc) - {64, 1, 1, 1, 64}, // ktrsc - {64, 64, 64, 64, 1}, // stride (ktrsc) - {8000, 8000, 800, 80, 1}, // stride (nzpqk) - {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) - {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 8, 8, 64}, - {16, 1, 1, 1, 64}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // N = 7 and K = 256 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1, 8, 8, 64}, - {96, 1, 1, 1, 64}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x3x3 + no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 64}, - {96, 3, 3, 3, 64}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x3x3 + symmetric padding with c % cta_k !=0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 32}, - {96, 3, 3, 3, 32}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + symmetric padding 111 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 64}, - {96, 3, 4, 5, 64}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 64}, - {96, 3, 4, 5, 64}, - {1, 0, 1}, - {0, 2, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010, w/ stride - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 64}, - {96, 3, 4, 5, 64}, - {1, 0, 1}, - {0, 2, 0}, - {2, 2, 3}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 64}, - {96, 3, 4, 5, 64}, - {1, 0, 1}, - {0, 2, 0}, - {1, 1, 1}, - {2, 2, 3}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 64}, - {96, 3, 4, 5, 64}, - {1, 0, 1}, - {0, 2, 0}, - {2, 2, 3}, - {2, 2, 3}, - 1 - }); - return problem_shapes; -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Wgrad -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Specialization for 1D wgrad problems -template<> -std::vector> inline -get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nwc - {64, 1, 64}, // ksc - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, - {16,1, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // N = 2 and K = 128 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 64}, - {96, 1, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // N = 7 and K = 256 for a even larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {7, 8, 64}, - {256, 1, 64}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // 3 filter, no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 32}, - {256, 3, 32}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // 3 filter, symmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 32}, - {256, 3, 32}, - {1}, - {1}, - {1}, - {1}, - 1 - }); - // 4 filter, asymmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 32}, - {256, 4, 32}, - {0}, - {1}, - {1}, - {1}, - 1 - }); - // 3 filter, asymmetric padding and tstride of 2 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 32}, - {256, 3, 32}, - {0}, - {1}, - {2}, - {1}, - 1 - }); - // 3 filter, asymmetric padding and dilation of 2 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 32}, - {256, 3, 32}, - {0}, - {1}, - {1}, - {2}, - 1 - }); - // To test streamk, equals to gemm-MxNxK size 128x640x2048 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1024, 128}, - {640, 1, 128}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // To test streamk, equals to gemm-MxNxK size 128x640x2080 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1040, 128}, - {640, 1, 128}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - return problem_shapes; -} - -// Specialization for 2D wgrad problems -template<> -std::vector> inline -get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // nhwc - {64, 1, 1, 64}, // krsc - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, - {16, 1, 1, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // N = 2 and K = 128 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 64}, - {96, 1, 1, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // N = 7 and K = 256 for a even larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {7, 8, 8, 64}, - {256, 1, 1, 64}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 3x3 filter, no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 32}, - {256, 3, 3, 32}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 3x3 filter, symmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 32}, - {256, 3, 3, 32}, - {1, 1}, - {1, 1}, - {1, 1}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 32}, - {256, 2, 5, 32}, - {1, 1}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 15, 16, 32}, - {256, 2, 5, 32}, - {1, 1}, - {0, 0}, - {2, 3}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 16, 32}, - {256, 2, 5, 32}, - {1, 1}, - {0, 0}, - {1, 1}, - {2, 3}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 15, 32}, - {256, 2, 5, 32}, - {1, 1}, - {0, 0}, - {2, 3}, - {2, 3}, - 1 - }); - // To test streamk, equals to gemm-MxNxK size 128x640x2048 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 64, 16, 128}, - {640, 1, 1, 128}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // To test streamk, equals to gemm-MxNxK size 128x640x2080 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 65, 16, 128}, - {640, 1, 1, 128}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - return problem_shapes; -} - -// Specialization for 3D wgrad problems -template<> -std::vector> inline -get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1, 8, 8, 64}, // ndhwc - {64, 1, 1, 1, 64}, // ktrsc - {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) - {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 1 // group - }); - // Filter 3x3x3 + no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 32}, - {96, 3, 3, 3, 32}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 32}, - {96, 3, 4, 5, 32}, - {1, 0, 1}, - {0, 2, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010, w/ stride - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 32}, - {96, 3, 4, 5, 32}, - {1, 0, 1}, - {0, 2, 0}, - {2, 2, 3}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 32}, - {96, 3, 4, 5, 32}, - {1, 0, 1}, - {0, 2, 0}, - {1, 1, 1}, - {2, 2, 3}, - 1 - }); - // To test streamk, equals to gemm-MxNxK size 128x640x2048 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1, 64, 16, 128}, - {640, 1, 1, 1, 128}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // To test streamk, equals to gemm-MxNxK size 128x640x2080 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1, 65, 16, 128}, - {640, 1, 1, 1, 128}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - return problem_shapes; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Grouped Wgrad -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Get problem size vectors for group conv problems -template -std::vector> -inline -get_grouped_conv_problem_vector(int GroupsPerTile); - -// Specialization for 3D wgrad problems -template<> -std::vector> inline -get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - - if (GroupsPerTile == 1) { - // channel_per_group == 64 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 16, 16, 2048}, // ndhwc - {2048, 1, 3, 3, 64}, // ktrsc - {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) - {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 32 // groups - }); - } - else if (GroupsPerTile == 2) { - // channel_per_group == 32 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 16, 16, 1024}, // ndhwc - {1024, 1, 3, 3, 32}, // ktrsc - {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) - {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 32 // groups - }); - } - else if (GroupsPerTile == 4) { - // channel_per_group == 16 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 16, 16, 512}, // ndhwc - {512, 1, 3, 3, 16}, // ktrsc - {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) - {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 32 // groups - }); - } - else if (GroupsPerTile == 8) { - // channel_per_group == 8 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 16, 16, 256}, // ndhwc - {256, 1, 3, 3, 8}, // ktrsc - {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) - {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 32 // groups - }); - } - return problem_shapes; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Unit Stride Dgrad -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Specialization for 1D dgrad problems -template<> -std::vector> inline -get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nqk - {64, 1, 64}, // ksc - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // non-packed input strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nqk - {800, 80, 1}, // stride (nqk) - {64, 1, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // non-packed output strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 64}, // nqk - {512, 64, 1}, // stride (nqk) - {64, 1, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {800, 80, 1}, // stride (nwc) - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {1}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 16}, - {64, 1, 16}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // N = 2 and K = 128 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 96}, - {64, 1, 96}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // N = 7 and K = 256 for a even larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {7, 8, 256}, - {64, 1, 256}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // 3 filter, no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 256}, - {64, 3, 256}, - {0}, - {0}, - {1}, - {1}, - 1 - }); - // 3 filter, symmetric padding with k % cta_k !=0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 256}, - {32, 3, 256}, - {1}, - {1}, - {1}, - {1}, - 1 - }); - // 4 filter, asymmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 256}, - {64, 4, 256}, - {0}, - {1}, - {1}, - {1}, - 1 - }); - // 3 filter, asymmetric padding and dilation of 2 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 64}, - {256, 3, 64}, - {0}, - {1}, - {1}, - {2}, - 1 - }); - return problem_shapes; -} - -// Specialization for 2D dgrad problems -template<> -std::vector> inline -get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // npqk - {64, 1, 1, 64}, // krsc - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // non-packed input strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // npqk - {8000, 800, 80, 1}, // stride (npqk) - {64, 1, 1, 64}, // krsc - {64, 64, 64, 1}, // stride (krsc) - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // non-packed output strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 64}, // npqk - {4096, 512, 64, 1}, // stride (npqk) - {64, 1, 1, 64}, // krsc - {64, 64, 64, 1}, // stride (krsc) - {8000, 800, 80, 1}, // stride (nhwc) - {0, 0}, // padding lower (pad_h, pad_w) - {0, 0}, // padding upper (pad_h, pad_w) - {1, 1}, // stride (stride_h, stride_w) - {1, 1}, // dilation (dilation_h, dilation_w) - 1 // group - }); - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 8, 8, 16}, - {64, 1, 1, 16}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // N = 2 and K = 128 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 96}, - {64, 1, 1, 96}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // N = 7 and K = 256 for a even larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {7, 8, 8, 256}, - {64, 1, 1, 256}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 3x3 filter, no padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 256}, - {64, 3, 3, 256}, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 3x3 filter, symmetric padding with k % cta_k !=0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 256}, - {32, 3, 3, 256}, - {1, 1}, - {1, 1}, - {1, 1}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 256}, - {64, 2, 5, 256}, - {1, 1}, - {0, 0}, - {1, 1}, - {1, 1}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 16, 64}, - {256, 2, 5, 64}, - {1, 1}, - {0, 0}, - {1, 1}, - {2, 3}, - 1 - }); - return problem_shapes; -} - -// Specialization for 3D dgrad problems -template<> -std::vector> inline -get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - // Filter-K = 16 for predication - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 8, 8, 16}, - {64, 1, 1, 1, 16}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // non-packed input output strides. - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1, 8, 8, 64}, // nzpqk - {8000, 8000, 800, 80, 1}, // stride (nzpqk) - {64, 1, 1, 1, 64}, // ktrsc - {64, 64, 64, 64, 1}, // stride (ktrsc) - {8000, 8000, 800, 80, 1}, // stride (ndhwc) - {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) - {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) - {1, 1, 1}, // stride (stride_d, stride_h, stride_w) - {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) - 1 // group - }); - // N = 7 and K = 256 for a larger grid - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 1, 8, 8, 96}, - {64, 1, 1, 1, 96}, - {0, 0, 0}, - {0, 0, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + symmetric padding 111 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 96}, - {64, 3, 4, 5, 96}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010 - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 3, 5, 8, 96}, - {64, 3, 4, 5, 96}, - {1, 0, 1}, - {0, 2, 0}, - {1, 1, 1}, - {1, 1, 1}, - 1 - }); - // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 64}, - {64, 3, 4, 5, 96}, - {1, 0, 1}, - {0, 2, 0}, - {1, 1, 1}, - {2, 2, 3}, - 1 - }); - return problem_shapes; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Strided Dgrad -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Specialization for 1D dgrad problems -template<> -std::vector> inline -get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - // Test TMA truncation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 512, 64}, // nqk - {64, 1, 64}, // ksc - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {2}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 1024, 64}, // nqk - {64, 1, 64}, // ksc - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {4}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {1, 2048, 64}, // nqk - {64, 1, 64}, // ksc - {0}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {8}, // stride (stride_w) - {1}, // dilation (dilation_w) - 1 // group - }); - // non-packed input/output strides. - // stride divides dilation - // asymmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {3, 8, 64}, // nqk - {800, 80, 1}, // stride (nqk) - {64, 3, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {800, 80, 1}, // stride (nwc) - {0}, // padding lower (pad_w) - {1}, // padding upper (pad_w) - {2}, // stride (stride_w) - {4}, // dilation (dilation_w) - 1 // group - }); - // non-packed input/output strides. - // dilation divides stride - // asymmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {3, 8, 64}, // nqk - {800, 80, 1}, // stride (nqk) - {64, 3, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {800, 80, 1}, // stride (nwc) - {1}, // padding lower (pad_w) - {0}, // padding upper (pad_w) - {4}, // stride (stride_w) - {2}, // dilation (dilation_w) - 1 // group - }); - // non-packed input/output strides. - // stride dilation dont divide - // asymmetric padding - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {3, 8, 64}, // nqk - {800, 80, 1}, // stride (nqk) - {64, 3, 64}, // ksc - {64, 64, 1}, // stride (ksc) - {800, 80, 1}, // stride (nwc) - {1}, // padding lower (pad_w) - {2}, // padding upper (pad_w) - {2}, // stride (stride_w) - {3}, // dilation (dilation_w) - 1 // group - }); - return problem_shapes; -} - -// Specialization for 2D dgrad problems -template<> -std::vector> inline -get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation - // mode 0 stride divides dilation - // mode 1 dilation divides stride - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {3, 16, 16, 64}, - {256, 2, 5, 64}, - {1, 0}, - {0, 1}, - {2, 4}, - {4, 2}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation - // mode 0 dilation divides stride - // mode 1 stride divides dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {3, 16, 16, 64}, - {256, 2, 5, 64}, - {1, 0}, - {0, 1}, - {4, 2}, - {2, 4}, - 1 - }); - // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation - // stride dilation dont divide - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {3, 16, 16, 64}, - {256, 2, 5, 64}, - {1, 0}, - {0, 1}, - {3, 2}, - {2, 3}, - 1 - }); - return problem_shapes; -} - -// Specialization for 3D dgrad problems -template<> -std::vector> inline -get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() { - using ProblemShape = cutlass::conv::ConvProblemShape; - std::vector problem_shapes; - // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation - problem_shapes.push_back({ - cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 10, 16, 64}, - {64, 3, 4, 5, 96}, - {1, 0, 1}, - {0, 2, 0}, - {2, 1, 2}, - {4, 2, 3}, - 1 - }); - return problem_shapes; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp deleted file mode 100644 index 99ba9c407cec38e919812fedeee38ba75d9129f7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp +++ /dev/null @@ -1,768 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Implicit GEMM testbed for 3.x API -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "../../common/cutlass_unit_test.h" - -#include "cute/tensor.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/convnd_problem_shape.hpp" -#include "../test/unit/gemm/device/gemm_testbed_3x.hpp" - -#include "thrust/universal_vector.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/host/conv.hpp" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/device/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "conv_problem_sizes.hpp" -#include "../cache_testbed_output.h" - -#include - -#include "cute/layout.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test::conv::device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Initializes a flat device buffer -template -static void -initialize_values( - thrust::universal_vector& dst_ptr, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - if (cutlass::Distribution::Uniform == dist_kind) { - int scope; - int bits = cutlass::sizeof_bits::value; - - if (bits <= 8) { - scope = 2; - } - else if (bits == 16) { - scope = 4; - } - else { - scope = 8; - } - cutlass::reference::host::BlockFillRandomUniform( - dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0); - } - else if (cutlass::Distribution::Identity == dist_kind) { - cutlass::reference::host::BlockFillRandomUniform( - dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0); - } - else if (cutlass::Distribution::Gaussian == dist_kind) { - cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5); - } - else if (cutlass::Distribution::Sequential == dist_kind) { - cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size()); - } - else { - std::cerr << "Invalid distribution kind!\n."; - exit(1); - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// utils for sparse or dense conv parameters - -template -struct DenseConvParams { - // Default Kernel data types - using ElementA = typename Conv::ConvKernel::ElementA; - using ElementB = typename Conv::ConvKernel::ElementB; - - static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; - static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; - using ProblemShape = cutlass::conv::ConvProblemShape; - - // get the default arguments without sparse data - auto get_mainloop_arguments( - [[maybe_unused]] ProblemShape const& problem_shape, - thrust::universal_vector& tensor_A, - thrust::universal_vector& tensor_B - ) { - auto args = typename Conv::ConvKernel::MainloopArguments { - tensor_A.data().get(), - tensor_B.data().get(), - }; - return args; - } -}; - -template -struct SparseConvParams { -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct ConvTestbed { - // Kernel data types - using ElementA = typename Conv::ConvKernel::ElementA; - using ElementB = typename Conv::ConvKernel::ElementB; - using ElementC = cute::conditional_t, - typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>; - using ElementD = typename Conv::ConvKernel::ElementD; - using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator; - - // ConvTest for sparse kernel - static constexpr bool isSparseEnabled = isSparseEnabled_; - using ConvParams = cute::conditional_t, DenseConvParams>; - ConvParams params; - - // - // FusionOperation derived types/queries - // - using FusionOp = typename Conv::EpilogueOutputOp; - - // fusion types are potentially void if the fusion is not supported - // helper so we don't try to construct HostTensor with void type - template - using non_void_t = cute::conditional_t, U, T>; - using ElementScalar = typename FusionOp::ElementScalar; - using ElementCompute = typename FusionOp::ElementCompute; - using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::type; - using ElementBias = non_void_t; - using ActivationType = non_void_t::type, - cutlass::epilogue::thread::Identity>; - static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation::value; - using ActivationFunctor = cute::conditional_t>; - - static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && - !cute::is_same_v; - static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled::value; - - static constexpr bool DisableSource = cute::is_void_v; - - static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd::value; - - using StrideC = typename Conv::ConvKernel::StrideC; - using StrideD = typename Conv::ConvKernel::StrideD; - using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; - - static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; - static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; - using ProblemShape = cutlass::conv::ConvProblemShape; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; - using Splits = typename gemm::device::detail::Splits; - - using Schedule = typename Conv::DispatchPolicy::Schedule; - /// Initialization - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform; - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform; - cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform; - cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; - cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros - uint64_t seed = 6090; - float epsilon = 0.0f; - int split_p_slices = 1; - thrust::universal_vector tensor_A; - thrust::universal_vector tensor_B; - thrust::universal_vector tensor_C; - thrust::universal_vector tensor_D_computed; - thrust::universal_vector tensor_D_reference; - thrust::universal_vector tensor_bias; - thrust::universal_vector tensor_alpha; - thrust::universal_vector tensor_beta; - - // Return true on success, else false - bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { - tensor_A.resize(sizeof(ElementA) * problem_shape.size_A()); - tensor_B.resize(sizeof(ElementB) * problem_shape.size_B()); - tensor_C.resize(sizeof(ElementC) * problem_shape.size_C()); - tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C()); - tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C()); - tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); - if constexpr (IsPerChannelScaleEnabled) { - tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); - tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); - } - initialize_values(tensor_A, init_A, seed); - initialize_values(tensor_B, init_B, seed * 11); - initialize_values(tensor_C, init_C, seed * 17); - initialize_values(tensor_bias, init_bias, seed * 19); - if constexpr (IsPerChannelScaleEnabled) { - initialize_values(tensor_alpha, init_bias, seed * 23); - if constexpr (DisableSource) { - initialize_values(tensor_beta, init_disable, seed * 27); - } - else { - initialize_values(tensor_beta, init_bias, seed * 27); - } - } - - bool flag = true; - if constexpr (isSparseEnabled) { - flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); - } - - return flag; - } - - // Determine SMEM requirements and waive if not satisfied - bool sufficient() const { - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - int max_smem_size; - result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx); - if (result != cudaSuccess) { - throw std::runtime_error("cudaDeviceGetAttribute() failed"); - } - - return max_smem_size >= Conv::ConvKernel::SharedStorageSize; - } - - auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) { - using TensorExtent = cute::array; - using TensorStride = cute::array; - - TensorExtent shape_a_g{}; - TensorExtent shape_b_g{}; - TensorExtent shape_c_g{}; - TensorStride stride_a_g{}; - TensorStride stride_b_g{}; - TensorStride stride_c_g{}; - - auto shape_a = cute::reverse(problem_shape.shape_A); - auto shape_b = cute::reverse(problem_shape.shape_B); - auto shape_c = cute::reverse(problem_shape.shape_C); - auto stride_a = cute::reverse(problem_shape.stride_A); - auto stride_b = cute::reverse(problem_shape.stride_B); - auto stride_c = cute::reverse(problem_shape.stride_C); - - int32_t G = problem_shape.groups; - - if constexpr (ConvOp == cutlass::conv::Operator::kFprop || - ConvOp == cutlass::conv::Operator::kDgrad) { - // shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g) - // shape_b_g = (c,s,r,k,t,g) - // shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g) - shape_a_g = cute::to_array(tuple_cat( - cute::make_shape(cute::size<0>(shape_a) / G), - cute::take<1,NumSpatialDimensions + 2>(shape_a), - cute::make_shape(G))); - shape_b_g = cute::to_array(tuple_cat( - cute::take<0,NumSpatialDimensions + 1>(shape_b), - cute::make_shape(cute::size(shape_b) / G, G))); - shape_c_g = cute::to_array(tuple_cat( - cute::make_shape(cute::size<0>(shape_c) / G), - cute::take<1,NumSpatialDimensions + 2>(shape_c), - cute::make_shape(G))); - - stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); - stride_b_g = cute::to_array(append(stride_b, - cute::size(stride_b) * cute::size(shape_b) / G)); - stride_c_g = cute::to_array(append(stride_c, cute::size<0>(shape_c) / G)); - } - else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { - // shape_a_g = (k,q,p,z,n,g) - // shape_b_g = (c,w,h,d,n,g) - // shape_c_g = (c,s,r,k,t,g) - shape_a_g = cute::to_array(tuple_cat( - cute::make_shape(cute::size<0>(shape_a) / G), - cute::take<1,NumSpatialDimensions + 2>(shape_a), - cute::make_shape(G))); - shape_b_g = cute::to_array(tuple_cat( - cute::make_shape(cute::size<0>(shape_b) / G), - cute::take<1,NumSpatialDimensions + 2>(shape_b), - cute::make_shape(G))); - shape_c_g = cute::to_array(tuple_cat( - cute::take<0,NumSpatialDimensions + 1>(shape_c), - cute::make_shape(cute::size(shape_c) / G, G))); - - stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); - stride_b_g = cute::to_array(append(stride_b, cute::size<0>(shape_b) / G)); - stride_c_g = cute::to_array(append(stride_c, - cute::size(stride_c) * cute::size(shape_c) / G)); - } - - return make_tuple(shape_a_g, shape_b_g, shape_c_g, - stride_a_g, stride_b_g, stride_c_g); - } - - // Executes one test - bool run( - ProblemShape const& problem_shape, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - dim3 cluster_shape = dim3(0, 0, 0), - dim3 cluster_shape_fallback = dim3(0, 0, 0), - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, - MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, - Splits splits = Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic - ) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device.\n"; - } - return true; - } - - bool ret = initialize(problem_shape); - - if (!ret) { - std::cerr << "initialize failed for the given problem_shape: \n"; - return false; - } - - cutlass::KernelHardwareInfo hw_info; - cudaGetDevice(&hw_info.device_id); - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - - hw_info.cluster_shape = cluster_shape; - hw_info.cluster_shape_fallback = cluster_shape_fallback; - - // configure the operator - Conv conv_op; - auto stride_C = StrideC{}; - auto stride_D = StrideD{}; - if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { - stride_C = cutlass::make_cute_packed_stride( - StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); - stride_D = cutlass::make_cute_packed_stride( - StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); - } - // Need to support non-packed output strides for fprop and dgrad kernel. - else { - cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { - cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; - }); - cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { - cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; - }); - } - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{}; - if constexpr (cute::is_same_v) { - scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; - } - - auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B); - - auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments { - {}, - tensor_C.data().get(), - stride_C, - tensor_D_computed.data().get(), - stride_D, - }; - - auto args = typename Conv::Arguments { - problem_shape, - mainloop_args, // MainloopArguments - epilogue_args, // EpilogueArguments - hw_info, - scheduler_args - }; - - auto &fusion_args = args.epilogue.thread; - - fusion_args.alpha = alpha; - fusion_args.beta = beta; - - if constexpr (IsPerChannelScaleEnabled) { - fusion_args.alpha_ptr = tensor_alpha.data().get(); - fusion_args.beta_ptr = tensor_beta.data().get(); - } - - if constexpr (IsBiasEnabled) { - fusion_args.bias_ptr = tensor_bias.data().get(); - } - - // Clamp bound - if constexpr (cute::is_same_v>) { - fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); - fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); - } - - // Scale - if constexpr (cute::is_same_v> || - cute::is_same_v> || - cute::is_same_v> || - cute::is_same_v> ) { - fusion_args.activation.scale = ElementCompute{1}; - } - - // LeakyRelu - if constexpr (cute::is_same_v> ) { - fusion_args.activation.leaky_alpha = ElementCompute{0}; - } - - cutlass::Status status = cutlass::Status::kInvalid; - - status = conv_op.can_implement(args); - EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - std::cerr << "can_implement failed for the given problem_shape: \n"; - print(problem_shape); - return false; - } - - // find workspace requirement for parallel split-k reduction - size_t workspace_size = Conv::get_workspace_size(args); - thrust::universal_vector workspace(workspace_size); - - status = conv_op.initialize(args, workspace.data().get()); - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // run conv3d operator - status = conv_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - - bool passed = false; - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: " - << cudaGetErrorString(result); - - // Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us - auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] = - transform_shape_and_stride_with_groups(problem_shape); - auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B()))); - - auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA)); - auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB)); - auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC)); - auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC)); - auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC)); - auto mBias = make_tensor(tensor_bias.data().get(), make_layout(shape_mBias)); - auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias)); - auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias)); - - cutlass::reference::host::ConvEpilogueFusionParams< - ElementAccumulator, - ElementScalar, - ElementCompute, - ElementC, - ElementD, - IsResidualEnabled, - decltype(mAlpha), - decltype(mBeta), - decltype(mBias), - ActivationFunctor> - epilogue_fusion_params{}; - - epilogue_fusion_params.alpha = alpha; - epilogue_fusion_params.beta = beta; - - if constexpr (IsPerChannelScaleEnabled) { - epilogue_fusion_params.tensor_alpha = mAlpha; - epilogue_fusion_params.tensor_beta = mBeta; - } - - if constexpr (IsBiasEnabled) { - epilogue_fusion_params.tensor_bias = mBias; - } - - auto padding = cute::reverse(problem_shape.lower_padding); - auto tstride = cute::reverse(problem_shape.traversal_stride); - auto dilation = cute::reverse(problem_shape.dilation); - - cutlass::reference::host::ConvReferenceImpl< - ConvOp, - NumSpatialDimensions, - decltype(mA), - decltype(mB), - decltype(mC), - decltype(mD_ref), - decltype(padding), - decltype(tstride), - decltype(dilation), - decltype(epilogue_fusion_params)> - reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params); - - // - // Reference check - support caching results - // - - CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey< - ProblemShape, - ElementA, - ElementB, - ElementC, - ElementD - >( - ConvOp, - problem_shape, - alpha, - beta, - tensor_A, - tensor_B, - tensor_C - ); - - // - // Look for the cached key - // - - bool cached_result_loaded = false; - CachedTestResult cached_test_result; - - std::string convnd_result_cache_name = - std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - - #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) - CachedTestResultListing cached_results(convnd_result_cache_name); - - auto cached = cached_results.find(cached_test_key); - - cached_result_loaded = cached.first; - if (cached_result_loaded) { - cached_test_result = cached.second; - } - #endif - - if (!cached_result_loaded) { - // Compute reference - reference_impl.compute_reference(); - - #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) - cached_test_result.D = TensorHash(tensor_D_reference); - CachedTestResultListing cached_results(convnd_result_cache_name); - - cached_results.append(cached_test_key, cached_test_result); - cached_results.write(convnd_result_cache_name); - #endif - } // if (!cached_result_loaded) - - #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) - uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed); - passed = (tensor_D_computed_hash == cached_test_result.D); - // If hash fails, double check against reference implementation. - if(!passed) { - std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key - << ", comparing with reference implementation now.\n"; - if (cached_result_loaded) { - // Compute reference - reference_impl.compute_reference(); - } - // Validate kernel against reference - passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); - } - #else - // Validate kernel against reference - passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); - #endif - - EXPECT_TRUE(passed); - return passed; - } - - template< - class Engine, class Layout, - class EngineA, class LayoutA, - class EngineB, class LayoutB, - class EngineAlpha, class LayoutAlpha, - class EngineBeta, class LayoutBeta, - class EngineBias, class LayoutBias> - static constexpr bool - compare_reference( - cute::Tensor const& reference, - cute::Tensor const& computed, - cute::Tensor const& A, - cute::Tensor const& B, - cute::Tensor const& tensor_alpha, - cute::Tensor const& tensor_beta, - cute::Tensor const& tensor_bias, - float epsilon = 0.0f) { - if (size(reference) != size(computed)) { - return false; - } - - bool passed = true; - if (epsilon == 0.0f) { - // fast refcheck w/o epsilon - for (size_t i = 0; i < size_t(size(reference)); ++i) { - if (reference(i) != computed(i)) { - passed = false; - printf("[%llu] %f, %f\n", static_cast(i), - float(reference(i)), float(computed(i))); - break; - } - } - } else { - // refcheck with epsilon - for (size_t i = 0; i < size_t(size(reference)); ++i) { - auto ref = static_cast(reference(i)); - auto act = static_cast(computed(i)); - auto abs_error = std::abs(act - ref); - auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f); - if (std::isnan(abs_error) || std::isnan(rel_error) || - std::min(abs_error, rel_error) > epsilon) { - passed = false; - printf("[%llu] %f, %f\n", static_cast(i), - float(reference(i)), float(computed(i))); - break; - } - } - } - #if CUTLASS_DEBUG_TRACE_LEVEL > 1 - if (not passed) { - cute::print("Reference:"); - cute::print_tensor(reference); - cute::print("\nComputed:"); - cute::print_tensor(computed); - cute::print("\n"); - - for (size_t i = 0; i < size_t(size(A)); ++i) { - printf("[%llu]: A = %f\n", static_cast(i), float(A(i))); - } - for (size_t i = 0; i < size_t(size(B)); ++i) { - printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); - } - if constexpr (IsPerChannelScaleEnabled) { - for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) { - printf("[%llu]: alpha = %f\n", static_cast(i), - float(tensor_alpha(i))); - } - for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) { - printf("[%llu]: beta = %f\n", static_cast(i), - float(tensor_beta(i))); - } - } - if constexpr (IsBiasEnabled) { - for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { - printf("[%llu]: bias = %f\n", static_cast(i), - float(tensor_bias(i))); - } - } - for (size_t i = 0; i < size_t(size(reference)); ++i) { - printf("[%llu]: ref = %f, computed = %f\n", static_cast(i), - float(reference(i)), float(computed(i))); - } - } - #endif - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f, - dim3 cluster_shape = dim3(0, 0, 0), - dim3 cluster_shape_fallback = dim3(0, 0, 0) - ) { - using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; - - bool passed = true; - ConvTestbed testbed; - testbed.epsilon = epsilon; - auto problem_vector = get_conv_problem_vector< - Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>(); - - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; - using Splits = typename gemm::device::detail::Splits; - - std::vector decomposition_modes = {DecompositionMode::Heuristic}; - static constexpr bool UsesStreamKScheduler = cute::is_same_v; - if constexpr (UsesStreamKScheduler) { - decomposition_modes.push_back(DecompositionMode::DataParallel); - decomposition_modes.push_back(DecompositionMode::SplitK); - decomposition_modes.push_back(DecompositionMode::StreamK); - } - - for (auto conv_problem : problem_vector) { - #if CUTLASS_DEBUG_TRACE_LEVEL > 0 - print(conv_problem); - #endif - for (DecompositionMode decomp_mode : decomposition_modes) { - std::vector problem_splits = {Splits{1}}; - if constexpr (UsesStreamKScheduler) { - if (decomp_mode == DecompositionMode::SplitK) { - problem_splits.push_back(Splits{2}); - problem_splits.push_back(Splits{4}); - } - } - for (auto splits : problem_splits) { - - passed = testbed.run( - conv_problem, - cutlass::from_real(alpha), - cutlass::from_real(beta), - cluster_shape, - cluster_shape_fallback, - RasterOrderOptions::Heuristic, // raster_order - MaxSwizzleSize(1), - splits, - decomp_mode - ); - if (!passed) { - printf("Failed test for "); print(conv_problem); - return false; - } - } // splits - } // decomposition_mode - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace test::conv::device - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp deleted file mode 100644 index ff170be142ff9d0d02cc684c2873c3ec014bd236..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp +++ /dev/null @@ -1,158 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "cutlass_unit_test.h" - -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -using namespace cute; - -template -struct SharedStorage -{ - cute::ArrayEngine> smem; -}; - -template -__global__ void -test_tiled_cp_async_device_cute(T const* g_in, T* g_out, - TiledCopy const tiled_copy, - GmemLayout gmem_layout, SmemLayout smem_layout) -{ - using namespace cute; - - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - - auto thr_copy = tiled_copy.get_slice(threadIdx.x); - Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); - Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout); - - // Construct SMEM tensor - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); - - auto tAgA = thr_copy.partition_S(gA); - auto tAsA = thr_copy.partition_D(sA); - -#if 0 - if (thread0()) { - print("gA : "); print(gA.layout()); print("\n"); - print("sA : "); print(sA.layout()); print("\n"); - print("tAgA: "); print(tAgA.layout()); print("\n"); - print("tAsA: "); print(tAsA.layout()); print("\n"); - } -#endif - - copy(tiled_copy, tAgA, tAsA); - - cp_async_fence(); - cp_async_wait<0>(); - __syncthreads(); - - // Store trivially smem -> gmem - - if (thread0()) { - copy(sA, gB); - } - -} - -template -void -test_tiled_cp_async( - TiledCopy const tiled_copy, - GMEM_Layout const& gmem_layout, - SMEM_Layout const& smem_layout) -{ - using namespace cute; - - // Allocate and initialize host test data - size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); - thrust::host_vector h_in(N); - Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); - for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } - - // Allocate and initialize device test data - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - // Launch - int smem_size = int(sizeof(SharedStorage)); - test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>( - reinterpret_cast(raw_pointer_cast(d_in.data())), - reinterpret_cast (raw_pointer_cast(d_out.data())), - tiled_copy, - gmem_layout, - smem_layout); - - // Copy results back to host - thrust::host_vector h_out = d_out; - Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); - - // Validate the results. Print only the first 3 errors. - int count = 3; - for (int i = 0; i < size(hA_out) && count > 0; ++i) { - EXPECT_EQ(hA_in(i), hA_out(i)); - if (hA_in(i) != hA_out(i)) { - --count; - } - } -} - -template -void test_cp_async_no_swizzle() { - using namespace cute; - auto smem_atom = SMEM_LAYOUT{}; - auto smem_layout = tile_to_shape(smem_atom, Shape{}); - auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); - test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); -} - -template -void test_cp_async_with_swizzle() { - using namespace cute; - auto swizzle_atom = SWIZZLE_ATOM{}; - auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{}); - auto smem_layout = tile_to_shape(smem_atom, Shape{}); - auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); - test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp deleted file mode 100644 index 3ff20d4087ee2fd6f4f74338e3e63eef27c221d3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp +++ /dev/null @@ -1,775 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/relatively_equal.h" -#include "cutlass_unit_test.h" -#include "cutlass/util/reference/host/tensor_compare.h" - -#include - -#include -#include - -#include - -using namespace cute; - -template -struct fp64_tester { - using value_type = double; -}; - -template -struct fp64_tester> { - using value_type = complex; -}; - -template // logical shape (M, N) -auto host_generate_gemm_inputs( - ALayout a_layout, - BLayout b_layout, - CLayout c_layout -) { - thrust::host_vector h_a(cosize(a_layout)); - thrust::host_vector h_b(cosize(b_layout)); - thrust::host_vector h_c(cosize(c_layout)); - thrust::host_vector h_c_out(cosize(c_layout)); - - auto h_a_tensor = make_tensor(h_a.data(), a_layout); - auto h_b_tensor = make_tensor(h_b.data(), b_layout); - auto h_c_tensor = make_tensor(h_c.data(), c_layout); - size_t max_size = std::max({static_cast(size(a_layout)), - static_cast(size(b_layout)), - static_cast(size(c_layout))}); - for (size_t i = 0; i < max_size; ++i) { - double di = static_cast(i); - if(i < size(a_layout)) { - h_a_tensor(i) = static_cast(di / size(a_layout)); - } - if(i < size(b_layout)) { - h_b_tensor(i) = static_cast(di / size(a_layout)); - } - if(i < size(c_layout)) { - h_c_tensor(i) = static_cast((di*di) / size(a_layout)); - } - } - - return std::make_tuple(h_a, h_b, h_c, h_c_out); -} - -template -thrust::host_vector -host_reference_gemm(Alpha alpha, - Tensor const& h_a_tensor, - Tensor const& h_b_tensor, - Beta beta, - Tensor const& h_c_tensor, - ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) - { - // Cannot use ::value_type because it propagates to complex::value_type, - // so ViewEngine>::value_type == double - using TA = remove_cv_t; - using TB = remove_cv_t; - using TC = remove_cv_t; - - using tester = fp64_tester; - using ABC_64 = typename tester::value_type; - - static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - - thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); - auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); - // A * B - for (int k = 0; k < size<1>(h_a_tensor); k++) { - for (int m = 0; m < size<0>(h_a_tensor); m++) { - for (int n = 0; n < size<0>(h_b_tensor); n++) { - const auto a_value = a_load_transform(h_a_tensor(m, k)); - const auto b_value = b_load_transform(h_b_tensor(n, k)); - const auto a_value_fp64 = static_cast(a_value); - const auto b_value_fp64 = static_cast(b_value); - h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); - } - } - } - // C = A*B + C - for (int i = 0; i < size(h_c_ref_tensor); i++) { - const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); - const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); - h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); - } - - return h_c_ref; -} - -template -void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, - cute::Tensor const& h_c_ref_tensor) -{ - // Cannot use ::value_type because it propagates to complex::value_type, - // so ViewEngine>::value_type == double - using TC = remove_cv_t; - - using tester = fp64_tester; - using ABC_64 = typename tester::value_type; - - for (int i = 0; i < size(h_c_ref_tensor); i++) { - ABC_64 h_c_ref_i = h_c_ref_tensor(i); - ABC_64 h_c_out_i = h_c_out_tensor(i); - double epsilon(0.1f); - double nonzero_floor(std::numeric_limits::min()); - bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); - ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; - } -} - - -template -__launch_bounds__(ThreadBlockSize) __global__ void -cooperative_gemm_kernel(GMemALayout gmem_a_layout, - GMemBLayout gmem_b_layout, - GMemCLayout gmem_c_layout, - SMemALayout smem_a_layout, - SMemBLayout smem_b_layout, - SMemCLayout smem_c_layout, - TA const* a, - TB const* b, - TC const* c, - TC * c_out, - Alpha const alpha, - Beta const beta, - TiledMma tiled_mma, - ALoadTransform a_load_transform, - BLoadTransform b_load_transform, - CLoadTransform c_load_transform, - CStoreTransform c_store_transform, - SMemCopyOpA a_copy_op, - SMemCopyOpB b_copy_op, - SMemCopyLdOpC c_copy_ld_op, - SMemCopyStOpC c_copy_st_op) -{ - using namespace cute; - - Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); - Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); - Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); - Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); - - constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - - extern __shared__ float4 smem_buf[]; - auto* smem_ptr = reinterpret_cast(smem_buf); - auto* smem_ptr_a = smem_ptr; - auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); - auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); - - Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); - Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); - Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); - - cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); - cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); - cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); - - cp_async_fence(); - cp_async_wait<0>(); - __syncthreads(); - - cooperative_gemm( - threadIdx.x, tiled_mma, - alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, - a_load_transform, b_load_transform, c_load_transform, c_store_transform, - a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op - ); - __syncthreads(); - - cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); -} - -template -__launch_bounds__(ThreadBlockSize) __global__ void -cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, - GMemBLayout gmem_b_layout, - GMemCLayout gmem_c_layout, - SMemALayout smem_a_layout, - SMemBLayout smem_b_layout, - TA const* a, - TB const* b, - TC const* c, - TC * c_out, - TiledMma tiled_mma, - ALoadTransform a_load_transform, - BLoadTransform b_load_transform, - CLoadTransform c_load_transform, - CStoreTransform c_store_transform, - SMemCopyOpA a_copy_op, - SMemCopyOpB b_copy_op) - { - using namespace cute; - - Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); - Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); - Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); - Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); - - constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - - extern __shared__ float4 smem_buf[]; - auto* smem_ptr = reinterpret_cast(smem_buf); - auto* smem_ptr_a = smem_ptr; - auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); - - Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); - Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); - - cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); - cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); - - cp_async_fence(); - cp_async_wait<0>(); - __syncthreads(); - - // Create C fragment for storing intermediate results - auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); - Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); - Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); - Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); - - // Create indexing help for predicated GEMMs - Tensor cC = make_identity_tensor(shape(gmem_c_layout)); - Tensor tCcC = thr_mma.partition_C(cC); - - // Load C from global - // (always loading in predicated way) - CUTE_UNROLL - for (int i = 0; i < size(r_c_partition); ++i) - { - if (elem_less(tCcC(i), shape(g_c_tensor))) - { - r_c_partition(i) = c_load_transform(g_c_partition(i)); - } - } - - cooperative_gemm( - threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, - a_load_transform, b_load_transform, a_copy_op, b_copy_op - ); - - __syncthreads(); - - // Store C to global - // (always storing in predicated way) - CUTE_UNROLL - for (int i = 0; i < size(r_c_partition); ++i) - { - if (elem_less(tCcC(i), shape(g_c_tensor))) - { - g_c_out_partition(i) = c_store_transform(r_c_partition(i)); - } - } -} - -template, - class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, - class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment, - class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment> -void test_cooperative_gemm(GMemALayout gmem_a_layout, - GMemBLayout gmem_b_layout, - GMemCLayout gmem_c_layout, - SMemALayout smem_a_layout, - SMemBLayout smem_b_layout, - SMemCLayout smem_c_layout, - TiledMma tiled_mma, - ALoadTransform a_load_transform = {}, - BLoadTransform b_load_transform = {}, - CLoadTransform c_load_transform = {}, - CStoreTransform c_store_transform = {}, - ASMemCopyOp a_smem_copy_op = {}, - BSMemCopyOp b_smem_copy_op = {}, - CSMemCopyLdOp c_smem_copy_ld_op = {}, - CSMemCopyStOp c_smem_copy_st_op = {}) -{ - static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - - static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM - static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN - static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK - - static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM - static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN - static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK - - static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); - static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); - static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); - -#if 0 - print(" "); print("gmem: "); print(gmem_layout); print("\n"); - print(" "); print("smem: "); print(smem_layout); print("\n"); - print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); -#endif - - const auto alpha = static_cast(1.1); - const auto beta = static_cast(1.2); - - // Generate inputs - auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); - - thrust::device_vector d_a(h_a); - thrust::device_vector d_b(h_b); - thrust::device_vector d_c(h_c); - thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); - - constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - - const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + - round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + - sizeof(TC) * h_c.size(); - - - auto kernel = cooperative_gemm_kernel< - ThreadBlockSize, CopyMaxVecBits, - GMemALayout, GMemBLayout, GMemCLayout, - SMemALayout, SMemBLayout, SMemCLayout, - TA, TB, TC, decltype(alpha), decltype(beta), - TiledMma, - ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, - ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp - >; - - ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); - - kernel<<<1, ThreadBlockSize, shared_memory_size>>>( - gmem_a_layout, - gmem_b_layout, - gmem_c_layout, - smem_a_layout, - smem_b_layout, - smem_c_layout, - thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c.data()), - thrust::raw_pointer_cast(d_c_out.data()), - alpha, - beta, - tiled_mma, - a_load_transform, - b_load_transform, - c_load_transform, - c_store_transform, - a_smem_copy_op, - b_smem_copy_op, - c_smem_copy_ld_op, - c_smem_copy_st_op - ); - - cudaError_t result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - cudaError_t error = cudaGetLastError(); - FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; - } - - // Reference gemm - auto h_c_ref = host_reference_gemm(alpha, - make_tensor(h_a.data(), gmem_a_layout), - make_tensor(h_b.data(), gmem_b_layout), - beta, - make_tensor(h_c.data(), gmem_c_layout), - a_load_transform, - b_load_transform, - c_load_transform, - c_store_transform); - - // Copy result data - h_c_out = d_c_out; - - // Verify correctness - verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), - make_tensor(h_c_ref.data(), gmem_c_layout)); -} - -template, - class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> -void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, - GMemBLayout gmem_b_layout, - GMemCLayout gmem_c_layout, - SMemALayout smem_a_layout, - SMemBLayout smem_b_layout, - TiledMma tiled_mma, - ALoadTransform a_load_transform = {}, - BLoadTransform b_load_transform = {}, - CLoadTransform c_load_transform = {}, - CStoreTransform c_store_transform = {}, - ASMemCopyOp a_smem_copy_op = {}, - BSMemCopyOp b_smem_copy_op = {}) -{ - static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM - static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN - static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK - - static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK - - static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); - static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); - -#if 0 - print(" "); print("gmem: "); print(gmem_layout); print("\n"); - print(" "); print("smem: "); print(smem_layout); print("\n"); - print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); -#endif - - const auto alpha = static_cast(1.0); - const auto beta = static_cast(1.0); - - // Generate inputs - auto [h_a, h_b, h_c, h_c_out] = - host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); - - thrust::device_vector d_a(h_a); - thrust::device_vector d_b(h_b); - thrust::device_vector d_c(h_c); - thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); - - constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - - const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + - round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); - - - auto kernel = cooperative_gemm_kernel_rmem_c< - ThreadBlockSize, CopyMaxVecBits, - GMemALayout, GMemBLayout, GMemCLayout, - SMemALayout, SMemBLayout, - TA, TB, TC, - TiledMma, - ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, - ASMemCopyOp, BSMemCopyOp - >; - - ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); - - kernel<<<1, ThreadBlockSize, shared_memory_size>>>( - gmem_a_layout, - gmem_b_layout, - gmem_c_layout, - smem_a_layout, - smem_b_layout, - thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c.data()), - thrust::raw_pointer_cast(d_c_out.data()), - tiled_mma, - a_load_transform, b_load_transform, c_load_transform, c_store_transform, - a_smem_copy_op, b_smem_copy_op - ); - - cudaError_t result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - cudaError_t error = cudaGetLastError(); - FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; - } - - // Copy result data - h_c_out = d_c_out; - - // Reference gemm - auto h_c_ref = host_reference_gemm(alpha, - make_tensor(h_a.data(), gmem_a_layout), - make_tensor(h_b.data(), gmem_b_layout), - beta, - make_tensor(h_c.data(), gmem_c_layout), - a_load_transform, - b_load_transform, - c_load_transform, - c_store_transform); - - // Verify correctness - verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), - make_tensor(h_c_ref.data(), gmem_c_layout)); -} - -template -void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, - TiledMma tiled_mma, - Ops ... ops) -{ - auto a_layout = make_layout(select<0, 2>(shape_mnk)); - auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); - auto c_layout = make_layout(select<0, 1>(shape_mnk)); - - test_cooperative_gemm - (a_layout, - b_layout, - c_layout, - a_layout, - b_layout, - c_layout, - tiled_mma, - ops...); -} - - -template -std::enable_if_t, - cute::is_layout, - cute::is_layout>> -test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, - SMemAtomLayoutB smem_atom_layout_b, - SMemAtomLayoutC smem_atom_layout_c, - ShapeMNK shape_mnk, - TiledMma tiled_mma, - Ops&& ... ops) -{ - auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); - auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); - auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); - - auto smem_a_layout = tile_to_shape( - smem_atom_layout_a, - make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); - - auto smem_b_layout = tile_to_shape( - smem_atom_layout_b, - make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); - - auto smem_c_layout = tile_to_shape( - smem_atom_layout_c, - make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); - - test_cooperative_gemm - (gmem_a_layout, - gmem_b_layout, - gmem_c_layout, - smem_a_layout, - smem_b_layout, - smem_c_layout, - tiled_mma, - ops...); -} - - -template -void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, - TiledMma tiled_mma, - Ops ... ops) -{ - auto a_layout = make_layout(select<0, 2>(shape_mnk)); - auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); - auto c_layout = make_layout(select<0, 1>(shape_mnk)); - - - test_cooperative_gemm_rmem_c - (a_layout, - b_layout, - c_layout, - a_layout, - b_layout, - tiled_mma, - ops...); -} - -template -std::enable_if_t, - cute::is_layout>> -test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, - SMemAtomLayoutB smem_atom_layout_b, - ShapeMNK shape_mnk, - TiledMma tiled_mma, - Ops ... ops) -{ - auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); - auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); - auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); - - auto smem_a_layout = tile_to_shape( - smem_atom_layout_a, - make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); - - auto smem_b_layout = tile_to_shape( - smem_atom_layout_b, - make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); - - test_cooperative_gemm_rmem_c - (gmem_a_layout, - gmem_b_layout, - gmem_c_layout, - smem_a_layout, - smem_b_layout, - tiled_mma, - ops...); -} - -template -void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) -{ - test_cooperative_gemm_col_major_layout_rmem_c, - T, T, T> - (static_cast(args)...); -} - -template -void test_cooperative_gemm_col_major_layout(Args&& ... args) -{ - test_cooperative_gemm_col_major_layout, - T, T, T> - (static_cast(args)...); -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp deleted file mode 100644 index 4d2620e62ff247e36ae49809ab4ef3416560ae31..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp +++ /dev/null @@ -1,217 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass_unit_test.h" - -#include -#include - -#include -#include - -#include - -namespace cutlass::test { - -template -struct SharedStorage -{ - cute::ArrayEngine> smem; - alignas(16) cute::uint64_t tma_load_mbar[1]; -}; - -#if CUDA_12_0_SM90_FEATURES_SUPPORTED - -template -__global__ void -tma_test_device_cute(T const* g_in, T* g_out, - CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, - GmemLayout gmem_layout, SmemLayout smem_layout) -{ - using namespace cute; - CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); - - // Use Shared Storage structure to allocate and distribute aligned SMEM addresses - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - - // Construct SMEM tensor - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) - // Shared memory barriers use 64bits in SMEM for synchronization - uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); - Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); - - constexpr int R = rank_v; - Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) - Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) - - // - // Prepare the TMA_LOAD - // - - auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) - Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) - -#if 0 - if (thread0()) { - print(tma); - print("TILE : "); print(cta_tiler); print("\n"); - print(" mA : "); print( mA); print("\n"); - print(" mB : "); print( mB); print("\n"); - print(" gA : "); print( gA); print("\n"); - print(" gB : "); print( gB); print("\n"); - print(" sA : "); print( sA); print("\n"); - print("tAgA_x: "); print(tAgA_x); print("\n"); - print("tAsA_x: "); print(tAsA_x); print("\n"); - } -#endif - - // - // Perform the TMA_LOAD - // - - // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles - Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST) - Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST) - static_assert(size<1>(tAsA) == 1); - - // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output - Tensor tBgB = group_modes<0,R>(group_modes(gB)); // (CTA_TILE, REST) - -#if 0 - if (thread0()) { - print("tAgA : "); print(tAgA); print("\n"); - print("tAsA : "); print(tAsA); print("\n"); - print("tBgB : "); print(tBgB); print("\n"); - } -#endif - - // Test L2 prefetch - if (threadIdx.x == 0) { - prefetch(tma, tAgA); - } - - // Loop over the TMA stages, using smem as our buffer - for (int stage = 0; stage < size<1>(tAgA); ++stage) - { - // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA))); - - if (threadIdx.x == 0) - { - /// Initialize shared memory barrier - tma_load_mbar[0] = 0; - cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); - cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); - - copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); - } - __syncthreads(); - - /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value - constexpr int kPhaseBit = 0; - cute::wait_barrier(tma_load_mbar[0], kPhaseBit); - - // - // Write out trivially smem -> gmem - // - - // Subbyte elements could cause race conditions, so be even more conservative - if (thread0()) { - copy(sA, tBgB(_,stage)); - } - - __syncthreads(); - } -} - -template -auto -test_tma_load(CopyOp const& copy_op, - GMEM_Layout const& gmem_layout, - SMEM_Layout const& smem_layout, - CTA_Tile const& cta_tile) -{ - using namespace cute; - - // Allocate and initialize host test data - size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); - thrust::host_vector h_in(N); - for (size_t i = 0; i < h_in.size(); ++i) { - h_in[i] = uint8_t(i % 13); - } - Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); - - // Allocate and initialize device test data - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint - - // Create TMA for this device Tensor - Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); - auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); - //print(tma); - - // Launch - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - reinterpret_cast(raw_pointer_cast(d_in.data())), - reinterpret_cast (raw_pointer_cast(d_out.data())), - tma, cta_tile, - gmem_layout, - smem_layout); - - // Copy results back to host - thrust::host_vector h_out = d_out; - Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); - - // Validate the results. Print only the first 3 errors. - int count = 3; - for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { - EXPECT_EQ(hA_in(i), hA_out(i)); - if (hA_in(i) != hA_out(i)) { - --count; - } - } - - return tma; -} - -#endif - -} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp deleted file mode 100644 index 3e0ec46df1b672c35c3c38f731c09b0134d4cd80..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp +++ /dev/null @@ -1,242 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass_unit_test.h" - -#include -#include - -#include -#include - -#include -#include -#include - -namespace cutlass::test { - -template -struct SharedStorage -{ - cute::ArrayEngine> smem; - alignas(16) cute::uint64_t tma_load_mbar[1]; -}; - -#if CUDA_12_0_SM90_FEATURES_SUPPORTED - -template -__global__ void -tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, - CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) -{ - using namespace cute; - CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); - - // Use Shared Storage structure to allocate and distribute aligned SMEM addresses - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - - // Construct SMEM tensor - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) - // Shared memory barriers use 64bits in SMEM for synchronization - uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); - Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); - - Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) - Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) - -#if 1 - if (thread0()) { - print(tma); - print("TILE : "); print(cta_tiler); print("\n"); - print(" mA : "); print( mA); print("\n"); - print(" mB : "); print( mB); print("\n"); - print(" gA : "); print( gA); print("\n"); - print(" gB : "); print( gB); print("\n"); - print(" sA : "); print( sA); print("\n"); - } __syncthreads(); cute::cluster_sync(); -#endif - - // - // Prepare the TMA_LOAD - // - - Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) - Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) - - int cta_rank_in_cluster = cute::block_rank_in_cluster(); - auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); - -#if 1 - if (thread0()) { - print("sA_x : "); print(sA_x); print("\n"); - print("tBgB : "); print(tBgB); print("\n"); - print("tAgA : "); print(tAgA); print("\n"); - print("tAsA : "); print(tAsA); print("\n"); - } __syncthreads(); cute::cluster_sync(); -#endif - - // - // TMA Multicast Masks -- Get a mask of the active ctas in each TMA - // - - - int elected_cta_rank = 0; - bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); - bool elect_one_thr = cute::elect_one_sync(); - - uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); - -#if 1 - if (thread0()) { - print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); - } __syncthreads(); cute::cluster_sync(); -#endif - - // - // Perform the TMA_LOAD - // - - if (elect_one_thr) { - // Initialize TMA barrier - cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); - } - int tma_phase_bit = 0; - // Ensures all CTAs in the Cluster have initialized - __syncthreads(); - cute::cluster_sync(); - - // Loop over the TMA stages, using smem as our buffer - for (int stage = 0; stage < size<1>(tAgA); ++stage) - { - // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); - - if (elect_one_thr) - { - cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); - - copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); - } - __syncthreads(); - - /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value - cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); - tma_phase_bit ^= 1; - - // - // Write out trivially smem -> gmem - // - - // Subbyte elements could cause race conditions, so be even more conservative - if (elect_one_cta && elect_one_thr) { - copy(sA, tBgB(_,stage)); - } - - __syncthreads(); - cute::cluster_sync(); - } -} - -template -auto -test_tma_load(CopyOp const& copy_op, - GMEM_Layout const& gmem_layout, - SMEM_Layout const& smem_layout, - CTA_Tiler const& cta_tiler, - Cluster_Size const& cluster_size) -{ - using namespace cute; - - // Allocate and initialize host test data - size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); - thrust::host_vector h_in(N); - for (size_t i = 0; i < h_in.size(); ++i) { - h_in[i] = uint8_t(i % 13); - } - Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); - - // Allocate and initialize device test data - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint - - // Create TMA for this device Tensor - Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); - auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); - //print(tma); - - // Launch - - dim3 dimBlock(32); - dim3 dimCluster(size(cluster_size)); - dim3 dimGrid = dimCluster; - int smem_size = sizeof(SharedStorage); - - void* kernel_ptr = (void*) &tma_test_device_cute; - - cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, - kernel_ptr, - reinterpret_cast(raw_pointer_cast(d_in.data())), - reinterpret_cast(raw_pointer_cast(d_out.data())), - gmem_layout, - smem_layout, - tma, cta_tiler, cluster_size); - - // Copy results back to host - thrust::host_vector h_out = d_out; - Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); - - // Validate the results. Print only the first 3 errors. - int count = 3; - for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { - EXPECT_EQ(hA_in(i), hA_out(i)); - if (hA_in(i) != hA_out(i)) { - --count; - } - } - - return tma; -} - -#endif - -} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp deleted file mode 100644 index 0429d2435fbf43c690f311c1f7c04f7025a2dd94..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp +++ /dev/null @@ -1,201 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass_unit_test.h" - -#include -#include - -#include -#include - -#include - -namespace cutlass::test { - -template -struct SharedStorage -{ - cute::ArrayEngine> smem; -}; - -#if CUDA_12_0_SM90_FEATURES_SUPPORTED - -template -__global__ void -tma_test_device_cute(T const* g_in, T* g_out, - CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, - GmemLayout gmem_layout, SmemLayout smem_layout) -{ - using namespace cute; - CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); - - // Use Shared Storage structure to allocate and distribute aligned SMEM addresses - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - - // Construct SMEM tensor - Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); - Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); - - constexpr int R = rank_v; - Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) - Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) - - // - // Prepare the TMA_STORE - // - - auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) - Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) - -#if 0 - if (thread0()) { - print(tma); - print("TILE : "); print(cta_tiler); print("\n"); - print(" mB : "); print( mB.data()); print(" o "); print( mB.layout()); print("\n"); - print(" gB : "); print( gB.data()); print(" o "); print( gB.layout()); print("\n"); - print("tBgB_x: "); print(tBgB_x.data()); print(" o "); print(tBgB_x.layout()); print("\n"); - print(" sB : "); print( sB.data()); print(" o "); print( sB.layout()); print("\n"); - print("tBsB_x: "); print(tBsB_x.data()); print(" o "); print(tBsB_x.layout()); print("\n"); - } -#endif - - // - // Perform the TMA_STORE - // - - // INPUT: Group the CTA_TILE_X modes and REST_X modes for input - Tensor tAgA = group_modes<0,R>(group_modes(gA)); // (CTA_TILE, REST) - - // OUTPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles - Tensor tBgB = group_modes<1,rank(tBgB_x)>(tBgB_x); // (TMA,REST) - Tensor tBsB = group_modes<1,rank(tBsB_x)>(tBsB_x); // (TMA,REST) - static_assert(size<1>(tBsB) == 1); - -#if 0 - if (thread0()) { - print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); - print("tBsB : "); print(tBsB.data()); print(" o "); print(tBsB.layout()); print("\n"); - print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); - } -#endif - - // Test L2 prefetch - cooperative_prefetch<128>(threadIdx.x, gA); - - // Loop over the TMA stages, using smem as our buffer - for (int stage = 0; stage < size<1>(tBgB); ++stage) - { - // - // Read in trivially gmem -> smem - // - // Subbyte elements could cause race conditions, so be even more conservative - if (thread0()) { - copy(tAgA(_,stage), sB); - } - - __syncthreads(); - cute::cp_async_wait<0>(); - - // - // Perform the TMA_STORE - // - - if (threadIdx.x == 0) { - copy(tma, tBsB(_,0), tBgB(_,stage)); - } - - tma_store_wait<0>(); - __syncthreads(); - } -} - -template -void -test_tma_store(CopyOp const& copy_op, - GMEM_Layout const& gmem_layout, - SMEM_Layout const& smem_layout, - CTA_Tile const& cta_tile) -{ - using namespace cute; - - // Allocate and initialize host test data - size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); - thrust::host_vector h_in(N); - for (size_t i = 0; i < h_in.size(); ++i) { - h_in[i] = uint8_t(i % 13); - } - Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); - - // Allocate and initialize device test data - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint - - // Create TMA for this device Tensor - Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_out.data())), gmem_layout); - auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); - //print(tma); - - // Launch - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - reinterpret_cast(raw_pointer_cast(d_in.data())), - reinterpret_cast (raw_pointer_cast(d_out.data())), - tma, cta_tile, - gmem_layout, - smem_layout); - - // Copy results back to host - thrust::host_vector h_out = d_out; - Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); - - // Validate the results. Print only the first 3 errors. - int count = 3; - for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { - EXPECT_EQ(hA_in(i), hA_out(i)); - if (hA_in(i) != hA_out(i)) { - --count; - } - } -} - -#endif - -} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h deleted file mode 100644 index 3163a0d0eaa24513ee210bd2b310d1bf233773a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h +++ /dev/null @@ -1,417 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - - \brief Unit tests for epilogues -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/half.h" -#include "cutlass/complex.h" - -#include "cutlass/epilogue/thread/linear_combination.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace kernel { - -template -__global__ void epilogue_with_reduction_threadblock( - typename Epilogue::ElementVector *ptr_Reduction, - typename Epilogue::OutputTileIterator::Params params_D, - typename Epilogue::OutputTileIterator::Element *ptr_D, - typename Epilogue::OutputTileIterator::Params params_C, - typename Epilogue::OutputTileIterator::Element *ptr_C, - typename Epilogue::TensorTileIterator::Params params_Tensor, - typename Epilogue::TensorTileIterator::Element *ptr_Tensor, - typename Epilogue::OutputOp::Params params_output_op, - cutlass::MatrixCoord problem_size, - cutlass::TensorRef< - typename Epilogue::WarpMmaOperator::ElementC, - typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, - int epilogue_count = 1) { - - __shared__ typename Epilogue::SharedStorage shared_storage; - - int thread_idx = threadIdx.x; - int warp_idx = threadIdx.x / 32; - int lane_idx = threadIdx.x % 32; - - // - // Construct the epilogue - // - - // Tile iterator writing to output tile - typename Epilogue::OutputTileIterator iterator_D( - params_D, - ptr_D, - problem_size, - thread_idx - ); - - // Tile iterator writing to output tile - typename Epilogue::OutputTileIterator iterator_C( - params_C, - ptr_C, - problem_size, - thread_idx - ); - - // Tile iterator writing to output tile - typename Epilogue::TensorTileIterator iterator_T( - params_Tensor, - ptr_Tensor, - problem_size, - thread_idx - ); - - // Epilogue operator - Epilogue epilogue( - shared_storage, - thread_idx, - warp_idx, - lane_idx); - - // - // Initialize the accumulators - // - - int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); - int warp_m = warp_mn % Epilogue::WarpCount::kM; - int warp_n = warp_mn / Epilogue::WarpCount::kM; - - accumulator_ref.add_coord_offset({ - warp_m * Epilogue::WarpMmaOperator::Shape::kM, - warp_n * Epilogue::WarpMmaOperator::Shape::kN}); - - typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); - - typename Epilogue::AccumulatorTile accumulators; - - accumulators.clear(); - accumulator_iterator.load(accumulators); - -#if 0 - // For debugging, enable this block of code to fill each accumulator element with its - // source thread ID. - CUTLASS_PRAGMA_UNROLL - for (size_t i = 0; i < accumulators.size(); ++i) { - typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); - accumulators[i] = x; - } - - __syncthreads(); - -#endif - - // - // Perform the epilogue operation - // - - typename Epilogue::OutputOp output_op(params_output_op); - - // Place the epilogue in a loop - for (int iter = 0; iter < epilogue_count; ++iter) { - epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); - } -} - -} // namespace kernel -} // namespace test - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Epilogue_ -> -class EpilogueWithReductionTestbed { -public: - - using Epilogue = Epilogue_; - using ElementAccumulator = typename Epilogue::ElementAccumulator; - using ElementCompute = typename Epilogue::OutputOp::ElementCompute; - using ElementTensor = typename Epilogue::TensorTileIterator::Element; - using ElementOutput = typename Epilogue::ElementOutput; - using OutputOpParams = typename Epilogue::OutputOp::Params; - -public: - - // - // Data members - // - - cutlass::MatrixCoord quantized_size; - cutlass::HostTensor accumulator_tensor; - cutlass::HostTensor source_tensor; - cutlass::HostTensor output_tensor; - cutlass::HostTensor additional_tensor; - cutlass::HostTensor reduction_tensor; - - -public: - - // - // Methods - // - - EpilogueWithReductionTestbed(): - quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), - accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - reduction_tensor({1, Epilogue::Shape::kN}) { - - // - // Initialize problem space - // - - uint64_t seed = 2019; - - cutlass::reference::host::TensorFillRandomUniform( - accumulator_tensor.host_view(), - seed, - 20, - -20, - 0); - - cutlass::reference::host::TensorFillRandomUniform( - source_tensor.host_view(), - seed + 2018, - 20, - -20, - 0); - - cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); - } - - bool run_all() { - - /* - double alpha_values[] = {1, 0, 2.25}; - double beta_values[] = {0, 1, -1.25}; - - // Test runtime explodes if we tried to test every case exhaustively. This tests the full - // output tile and several smaller sizes to stress predication. - for (int m_idx = 0; m_idx < 3; ++m_idx) { - for (int n_idx = 0; n_idx < 3; ++n_idx) { - - int m = quantized_size.row() - m_idx * 3; - int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; - - for (double const &alpha : alpha_values) { - for (double const &beta : beta_values) { - - bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); - - if (!passed) { - return false; - } - } - } - } - } - return true; - */ - - double alpha = 1; - double beta = 0; - - return run( - {quantized_size.row(), quantized_size.column()}, - {cutlass::from_real(alpha), cutlass::from_real(beta)}); - } - - /// Runs the test - bool run( - cutlass::MatrixCoord problem_size, - OutputOpParams output_params) { - - // - // Initialize problem space - // - - ElementOutput default_output = ElementOutput(-127); - ElementAccumulator default_reduction = ElementAccumulator(); - - cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); - cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); - - accumulator_tensor.sync_device(); - output_tensor.sync_device(); - source_tensor.sync_device(); - additional_tensor.sync_device(); - reduction_tensor.sync_device(); - - // - // Initialize epilogue parameters - // - - typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); - typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); - typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); - - // - // Launch kernel - // - - dim3 grid(1, 1); - dim3 block(Epilogue::WarpCount::kCount * 32, 1); - - test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( - reduction_tensor.device_data(), - params_D, - output_tensor.device_data(), - params_C, - source_tensor.device_data(), - params_T, - additional_tensor.device_data(), - output_params, - problem_size, - accumulator_tensor.device_view()); - - cudaError_t result = cudaDeviceSynchronize(); - - if (result != cudaSuccess) { - std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; - return false; - } - - // - // Verify results - // - output_tensor.sync_host(); - reduction_tensor.sync_host(); - - int errors = 0; - int const kMaxErrors = 5; - - // - // The output has two parts: - // - GEMM tensor epilogue in canonical layout - // - partial reduction in canonical row-major layout - // - - // Verify the GEMM tensor output - for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { - for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { - - cutlass::MatrixCoord coord{r, c}; - ElementOutput got = output_tensor.at(coord); - - ElementOutput expected; - if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { - - expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + - output_params.beta * ElementCompute(source_tensor.at(coord))); - } - else { - expected = default_output; - } - - if (expected != got) { - - using OutputIO = cutlass::ScalarIO; - - EXPECT_TRUE(false) - << "-------\n" - << "Error - output element (" << coord << ") - expected: " - << OutputIO(expected) - << ", got: " << OutputIO(got) << std::endl; - - ++errors; - } - } - } - - // Verify the partial reduction - for (int c = 0; c < quantized_size.column(); ++c) { - - ElementAccumulator reduction_acc = ElementAccumulator(); - - for (int r = 0; r < quantized_size.row(); ++r) { - reduction_acc += accumulator_tensor.at({r, c}); - } - - ElementAccumulator expected = default_reduction; - ElementAccumulator got = reduction_tensor.at({0, c}); - - if (c < problem_size.column()) { - expected = reduction_acc; - } - else { - expected = default_reduction; - } - - if (expected != got) { - - using OutputIO = cutlass::ScalarIO; - - EXPECT_TRUE(false) - << "-------\n" - << "Error - reduction element (" << c << ") - expected: " - << OutputIO(expected) - << ", got: " << OutputIO(got) << std::endl; - } - } - - // - // Report results on error - // - - if (errors) { - std::stringstream ss; - ss - << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" - << Epilogue::WarpTileIterator::WarpShape::kM << "x" - << Epilogue::WarpTileIterator::WarpShape::kN - << "_slice_" << Epilogue::WarpCount::kK << ".csv"; - - std::ofstream output_file(ss.str()); - output_file << output_tensor.host_view(); - } - - return !errors; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h deleted file mode 100644 index e2457fdb4817e1dfb3af73149ae1e4c4458670a2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h +++ /dev/null @@ -1,356 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for epilogues -*/ -#pragma once - -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/half.h" -#include "cutlass/complex.h" -#include "cutlass/quaternion.h" -#include "cutlass/platform/platform.h" -#include "cutlass/epilogue/thread/linear_combination.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace kernel { - -template -__global__ void epilogue_threadblock( - typename Epilogue::OutputTileIterator::Params params_D, - typename Epilogue::OutputTileIterator::Element *ptr_D, - typename Epilogue::OutputTileIterator::Params params_C, - typename Epilogue::OutputTileIterator::Element *ptr_C, - typename Epilogue::OutputOp::Params params_output_op, - cutlass::MatrixCoord problem_size, - cutlass::TensorRef< - typename Epilogue::WarpMmaOperator::ElementC, - typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, - int epilogue_count = 1) { - - __shared__ typename Epilogue::SharedStorage shared_storage; - - int thread_idx = threadIdx.x; - int warp_idx = threadIdx.x / 32; - int lane_idx = threadIdx.x % 32; - - // - // Construct the epilogue - // - - // Tile iterator writing to output tile - typename Epilogue::OutputTileIterator iterator_D( - params_D, - ptr_D, - problem_size, - thread_idx - ); - - // Tile iterator writing to output tile - typename Epilogue::OutputTileIterator iterator_C( - params_C, - ptr_C, - problem_size, - thread_idx - ); - - // Epilogue operator - Epilogue epilogue( - shared_storage, - thread_idx, - warp_idx, - lane_idx); - - // - // Initialize the accumulators - // - - int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); - int warp_m = warp_mn % Epilogue::WarpCount::kM; - int warp_n = warp_mn / Epilogue::WarpCount::kM; - - accumulator_ref.add_coord_offset({ - warp_m * Epilogue::WarpMmaOperator::Shape::kM, - warp_n * Epilogue::WarpMmaOperator::Shape::kN}); - - typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); - - typename Epilogue::AccumulatorTile accumulators; - - accumulators.clear(); - accumulator_iterator.load(accumulators); - -#if 0 - // For debugging, enable this block of code to fill each accumulator element with its - // source thread ID. - CUTLASS_PRAGMA_UNROLL - for (size_t i = 0; i < accumulators.size(); ++i) { - typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); - accumulators[i] = x; - } - - __syncthreads(); - -#endif - - // - // Perform the epilogue operation - // - - typename Epilogue::OutputOp output_op(params_output_op); - - // Place the epilogue in a loop - for (int iter = 0; iter < epilogue_count; ++iter) { - epilogue(output_op, iterator_D, accumulators, iterator_C); - } -} - -} // namespace kernel -} // namespace test - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Epilogue_ -> -class EpilogueTestbed { -public: - - using Epilogue = Epilogue_; - using ElementAccumulator = typename Epilogue::ElementAccumulator; - using ElementCompute = typename Epilogue::OutputOp::ElementCompute; - using ElementOutput = typename Epilogue::ElementOutput; - using OutputOpParams = typename Epilogue::OutputOp::Params; - -public: - - // - // Data members - // - - cutlass::MatrixCoord quantized_size; - cutlass::HostTensor accumulator_tensor; - cutlass::HostTensor source_tensor; - cutlass::HostTensor output_tensor; - -public: - - // - // Methods - // - - EpilogueTestbed(): - quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), - accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { - - // - // Initialize problem space - // - - uint64_t seed = 2019; - - cutlass::reference::host::TensorFillRandomUniform( - accumulator_tensor.host_view(), - seed, - 2, - -2, - 0); - - cutlass::reference::host::TensorFillRandomUniform( - source_tensor.host_view(), - seed + 2018, - 2, - -2, - 0); - } - - bool run_all() { - - double alpha_values[] = {1, 0, 2.25}; - double beta_values[] = {0, 1, -1.25}; - - // Test runtime explodes if we tried to test every case exhaustively. This tests the full - // output tile and several smaller sizes to stress predication. - for (int m_idx = 0; m_idx < 3; ++m_idx) { - for (int n_idx = 0; n_idx < 3; ++n_idx) { - - int m = quantized_size.row() - m_idx * 3; - int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; - - for (double const &alpha : alpha_values) { - for (double const &beta : beta_values) { - - bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); - - if (!passed) { - return false; - } - } - } - } - } - - return true; - } - - /// Runs the test - bool run( - cutlass::MatrixCoord problem_size, - OutputOpParams output_params) { - - // - // Initialize problem space - // - - ElementOutput default_output = ElementOutput(-127); - cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); - - accumulator_tensor.sync_device(); - output_tensor.sync_device(); - source_tensor.sync_device(); - - // - // Initialize epilogue parameters - // - - typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); - typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); - - // - // Launch kernel - // - - dim3 grid(1, 1); - dim3 block(Epilogue::WarpCount::kCount * 32, 1); - - test::kernel::epilogue_threadblock<<< grid, block >>>( - params_D, - output_tensor.device_data(), - params_C, - source_tensor.device_data(), - output_params, - problem_size, - accumulator_tensor.device_view()); - - cudaError_t result = cudaDeviceSynchronize(); - - if (result != cudaSuccess) { - std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; - return false; - } - - // - // Verify results - // - output_tensor.sync_host(); - - int errors = 0; - int const kMaxErrors = 5; - - for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { - for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { - - cutlass::MatrixCoord coord{r, c}; - ElementOutput got = output_tensor.at(coord); - - ElementOutput expected; - if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { - ElementCompute intermediate = - output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + - output_params.beta * ElementCompute(source_tensor.at(coord)); - - if ((cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || std::numeric_limits::is_integer) - && !std::numeric_limits::is_integer) { - std::fesetround(FE_TONEAREST); - expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); - } else { - expected = ElementOutput(intermediate); - } - } else { - expected = default_output; - } - - if (expected != got) { - - using OutputIO = cutlass::ScalarIO; - - EXPECT_TRUE(false) - << "-------\n" - << "Error - output element (" << coord << ") - expected: " - << OutputIO(expected) - << ", got: " << OutputIO(got) - << ", accum: " << (accumulator_tensor.at(coord)) - << ", source: " << OutputIO(source_tensor.at(coord)) - << ", alpha: " << (output_params.alpha) - << ", beta: " << (output_params.beta) << "\n"; - - ++errors; - } - } - } - - // - // Report results on error - // - - if (errors) { - std::stringstream ss; - ss - << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" - << Epilogue::WarpTileIterator::WarpShape::kM << "x" - << Epilogue::WarpTileIterator::WarpShape::kN - << "_slice_" << Epilogue::WarpCount::kK << ".csv"; - - std::ofstream output_file(ss.str()); - output_file << output_tensor.host_view(); - } - - return !errors; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h deleted file mode 100644 index a76578f7638ac1d30161a9bcb55ecec70b5c43e0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h +++ /dev/null @@ -1,394 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for epilogues -*/ -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/half.h" -#include "cutlass/complex.h" - -#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" - -#include "cutlass/util/host_tensor_planar_complex.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace kernel { - -template -__global__ void epilogue_planar_complex_threadblock( - typename Epilogue::OutputTileIterator::Params params_D, - typename Epilogue::OutputTileIterator::Element *ptr_D, - int64_t imaginary_stride_D, - typename Epilogue::OutputTileIterator::Params params_C, - typename Epilogue::OutputTileIterator::Element *ptr_C, - int64_t imaginary_stride_C, - typename Epilogue::OutputOp::Params params_output_op, - cutlass::MatrixCoord problem_size, - cutlass::TensorRef< - typename Epilogue::WarpMmaOperator::ElementC, - typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, - int64_t imaginary_stride_accum, - int epilogue_count = 1) { - - __shared__ typename Epilogue::SharedStorage shared_storage; - - int thread_idx = threadIdx.x; - int warp_idx = threadIdx.x / 32; - int lane_idx = threadIdx.x % 32; - - // - // Construct the epilogue - // - - // Tile iterator writing to output tile - typename Epilogue::OutputTileIterator iterator_D_real( - params_D, - ptr_D, - problem_size, - thread_idx - ); - - typename Epilogue::OutputTileIterator iterator_D_imag( - params_D, - ptr_D + imaginary_stride_D, - problem_size, - thread_idx - ); - - // Tile iterator writing to output tile - typename Epilogue::OutputTileIterator iterator_C_real( - params_C, - ptr_C, - problem_size, - thread_idx - ); - - typename Epilogue::OutputTileIterator iterator_C_imag( - params_C, - ptr_C + imaginary_stride_C, - problem_size, - thread_idx - ); - - // Epilogue operator - Epilogue epilogue( - shared_storage, - thread_idx, - warp_idx, - lane_idx); - - // - // Initialize the accumulators - // - - int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); - int warp_m = warp_mn % Epilogue::WarpCount::kM; - int warp_n = warp_mn / Epilogue::WarpCount::kM; - - accumulator_ref.add_coord_offset({ - warp_m * Epilogue::WarpMmaOperator::Shape::kM, - warp_n * Epilogue::WarpMmaOperator::Shape::kN}); - - // - // Load accumulators - // - - typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); - - typename Epilogue::AccumulatorTile accumulators; - - accumulators.clear(); - - accumulator_iterator.load(accumulators.real); - accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); - - // - // Perform the epilogue operation - // - - typename Epilogue::OutputOp output_op(params_output_op); - - // Place the epilogue in a loop so assembly is clearly visible - for (int iter = 0; iter < epilogue_count; ++iter) { - epilogue( - output_op, - iterator_D_real, - iterator_D_imag, - accumulators, - iterator_C_real, - iterator_C_imag); - } -} - -} // namespace kernel -} // namespace test - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Epilogue_ -> -class EpiloguePlanarComplexTestbed { -public: - - using Epilogue = Epilogue_; - using ElementAccumulator = typename Epilogue::ElementAccumulator; - using ElementCompute = typename Epilogue::OutputOp::ElementCompute; - using ElementOutput = typename Epilogue::ElementOutput; - using OutputOpParams = typename Epilogue::OutputOp::Params; - - using ComplexElementOutput = cutlass::complex; - using ComplexElementAccumulator = cutlass::complex; - using ComplexElementCompute = cutlass::complex; - -public: - - // - // Data members - // - - cutlass::MatrixCoord quantized_size; - cutlass::HostTensorPlanarComplex accumulator_tensor; - cutlass::HostTensorPlanarComplex source_tensor; - cutlass::HostTensorPlanarComplex output_tensor; - -public: - - // - // Methods - // - - EpiloguePlanarComplexTestbed(): - quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), - accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), - output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { - - // - // Initialize problem space - // - - #if 1 - uint64_t seed = 2019; - - cutlass::reference::host::TensorFillRandomUniform( - accumulator_tensor.host_view(), - seed, - 20, - -20, - 0); - - cutlass::reference::host::TensorFillRandomUniform( - source_tensor.host_view(), - seed + 2018, - 20, - -20, - 0); - #else - - cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); - - #endif - } - - bool run_all() { - - cutlass::complex alpha_values[3]; - - alpha_values[0] = cutlass::complex(1, 0); - alpha_values[1] = cutlass::complex(0, 0); - alpha_values[2] = cutlass::complex(2.25f, -0.5f); - - cutlass::complex beta_values[3]; - - beta_values[0] = cutlass::complex(0, 0); - beta_values[1] = cutlass::complex(1, 0); - beta_values[2] = cutlass::complex(0.5f, -2.25f); - - // Test runtime explodes if we tried to test every case exhaustively. This tests the full - // output tile and several smaller sizes to stress predication. - for (int m_idx = 0; m_idx < 3; ++m_idx) { - for (int n_idx = 0; n_idx < 3; ++n_idx) { - - cutlass::MatrixCoord problem_size( - quantized_size.row() - m_idx * 3, - quantized_size.column() - n_idx * Epilogue::kElementsPerAccess - ); - - for (auto const &alpha : alpha_values) { - for (auto const &beta : beta_values) { - - bool passed = run(problem_size, {alpha, beta}); - - if (!passed) { - return false; - } - } - } - } - } - - return true; - } - - /// Runs the test - bool run( - cutlass::MatrixCoord problem_size, - OutputOpParams output_params) { - - // - // Initialize problem space - // - - ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); - - cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); - - accumulator_tensor.sync_device(); - output_tensor.sync_device(); - source_tensor.sync_device(); - - // - // Initialize epilogue parameters - // - - typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); - typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); - - // - // Launch kernel - // - - dim3 grid(1, 1); - dim3 block(Epilogue::WarpCount::kCount * 32, 1); - - test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( - params_D, - output_tensor.device_data(), - output_tensor.imaginary_stride(), - params_C, - source_tensor.device_data(), - source_tensor.imaginary_stride(), - output_params, - problem_size, - accumulator_tensor.device_view_real(), - accumulator_tensor.imaginary_stride() - ); - - cudaError_t result = cudaDeviceSynchronize(); - - if (result != cudaSuccess) { - std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; - return false; - } - - // - // Verify results - // - output_tensor.sync_host(); - - int errors = 0; - int const kMaxErrors = 5; - - for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { - for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { - - cutlass::MatrixCoord coord{r, c}; - ComplexElementOutput got = output_tensor.at(coord); - - ComplexElementOutput expected = default_output; - - if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { - - ComplexElementOutput src = source_tensor.at(coord); - - ComplexElementCompute tmp = - output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + - output_params.beta * ComplexElementCompute(src.real(), src.imag()); - - expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); - } - - if (expected != got) { - - using OutputIO = cutlass::ScalarIO; - - EXPECT_TRUE(false) - << "-------\n" - << "Error - output element (" << coord << ") - expected: " - << OutputIO(expected) - << ", got: " << OutputIO(got) << std::endl; - - ++errors; - } - } - } - - // - // Report results on error - // - - if (errors) { - - - std::cout << "Incorrect result for problem(" - << problem_size.row() << ", " - << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; - - std::stringstream ss; - ss - << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" - << Epilogue::WarpTileIterator::WarpShape::kM << "x" - << Epilogue::WarpTileIterator::WarpShape::kN - << "_slice_" << Epilogue::WarpCount::kK << ".csv"; - - std::ofstream output_file(ss.str()); - output_file << output_tensor.host_view(); - - std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; - } - - return !errors; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp deleted file mode 100644 index 0054a1b6757a232e9177407fdd2041b6a91cffb9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp +++ /dev/null @@ -1,1384 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/atom/mma_atom.hpp" -#include "cute/atom/copy_atom.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/layout/layout.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/collective/collective_mma.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" - -namespace cutlass { -namespace gemm { -namespace device { -using namespace cute; - -// This type is only intended to demonstrate porting 2.x kernels to 3.0 -template< - class OperatorClass, class ArchTag, - class ElementA, class LayoutA, - class ElementB, class LayoutB, - class ElementC, class LayoutC, - class ElementAccumulator> -struct DefaultGemmConfigurationToCutlass3Types { - static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); -}; - -/////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct DefaultGemm_TensorOpSm80_OperandA; - -template -struct DefaultGemm_TensorOpSm80_OperandB; - -// -// F16: 128-by-128-by-64 -// - -/// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride<_64, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, half_t>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); -}; - -/// Operand A - Column-major (M-major) -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride< _1,_64>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, half_t>{}, - Layout, - Stride< _1,_16>>{}, - Layout>{})); -}; - -// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands - -// Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; - -// Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; - -// -// F16: 128-by-128-by-32 (small k-block) -// - -/// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<2,3,3>{}, - Layout, - Stride<_32, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, half_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>{})); -}; - -} - -/////////////////////////////////////////////////////////////////////////////// - -// Ampere MMA F32F16 -template -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - half_t, LayoutA, - half_t, LayoutB, - float, LayoutC, - float> -{ - using TileShape = Shape<_128, _128, _32>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, // 2x2x1 thread group - Tile<_32,_32,_16>>; // 32x32x16 MMA for LDSM, 1x2x1 value group - - // A - static constexpr int kAlignmentA = 8; - using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< - half_t, LayoutA, kAlignmentA, 32>; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - // B - static constexpr int kAlignmentB = 8; - using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< - half_t, LayoutB, kAlignmentB, 32>; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - half_t, TagToStrideA_t, - half_t, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - float, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -// -// TF32: 128-by-128-by-kblock (kBlock = 16, 32) -// - -/// Operand A - Row-major (K-major) (kBlock = 32) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,2,3>{}, - Layout, - Stride<_32, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, tfloat32_t>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); -}; - -/// Operand A - Row-major (K-major) (kBlock = 16) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<2,2,3>{}, - Layout, - Stride<_16, _1>>{})); - using SmemCopyAtom = Copy_Atom; - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, tfloat32_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>{})); -}; - -/// Operand A - Column-major (M-major) -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,2,3>{}, - Layout, - Stride< _1,_32>>{})); - using SmemCopyAtom = Copy_Atom, tfloat32_t>; - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, tfloat32_t>{}, - Layout, - Stride< _1,_16>>{}, - Layout>{})); -}; - -// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands - -// Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; - -// Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; - -} - -/////////////////////////////////////////////////////////////////////////////// - -// Ampere MMA F32TF32 -template -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - tfloat32_t, LayoutA, - tfloat32_t, LayoutB, - float, LayoutC, - float> -{ - using TileShape = Shape<_128, _128, _32>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, - Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group - Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group - - // A - static constexpr int kAlignmentA = 4; - using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< - tfloat32_t, LayoutA, kAlignmentA, 32>; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - // B - static constexpr int kAlignmentB = 4; - using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< - tfloat32_t, LayoutB, kAlignmentB, 32>; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - tfloat32_t, TagToStrideA_t, - tfloat32_t, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - float, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// -template -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - int8_t, cutlass::layout::RowMajor, - int8_t, cutlass::layout::ColumnMajor, - int32_t, LayoutC, - int32_t> -{ - using TileShape = Shape<_128, _128, _64>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, // 2x2x1 thread group - Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group - - // A (M,K) K-major - using SmemLayoutAtomA = decltype( - composition( - Swizzle<2,4,3>{}, - Layout, - Stride<_64, _1>>{})); - static constexpr int kAlignmentA = 16; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, int8_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>>{})); - // LDS.32- or LDSM-based copy atom - // using SmemCopyAtomA = Copy_Atom; - using SmemCopyAtomA = Copy_Atom; // LDSM works - - // B (N,K) K-major - using SmemLayoutAtomB = decltype( - composition( - Swizzle<2,4,3>{}, - Layout, - Stride<_64, _1>>{})); - static constexpr int kAlignmentB = 16; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, int8_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>>{})); - - // LDS.32- or LDSM-based copy atom - // using SmemCopyAtomB = Copy_Atom; - using SmemCopyAtomB = Copy_Atom; // LDSM works - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - int8_t, TagToStrideA_t, - int8_t, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - int32_t, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// -//////////////////////////// SIMT TWO STAGE /////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct DefaultGemm_Simt_OperandA; - -/////////////////////////////////////////////////////////////////////////////// - -template -struct DefaultGemm_Simt_OperandA -{ - using SmemLayoutAtom = Layout, - Stride< _1,_128>>; - - using SmemCopyAtom = Copy_Atom; - - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - Layout, - Stride< _1,_32>>{}, - Layout>{})); -}; - -template -struct DefaultGemm_Simt_OperandA -{ - using SmemLayoutAtom = Layout, - Stride< _1,Int<128 + 4>>>; // Padded - - using SmemCopyAtom = Copy_Atom; - - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - Layout, - Stride< _8, _1>>{}, - Layout>{})); - -}; - -template -struct DefaultGemm_Simt_OperandB; - -template -struct DefaultGemm_Simt_OperandB - : DefaultGemm_Simt_OperandA {}; - -template -struct DefaultGemm_Simt_OperandB - : DefaultGemm_Simt_OperandA {}; - -} // end namespace detail - -// SIMT Two Stage -template < - class ArchTag, - class ElementA, class LayoutA, - class ElementB, class LayoutB, - class ElementC, class LayoutC, - class ElementAccumulator> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, ArchTag, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator> -{ - using TileShape = Shape<_128, _128, _8>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm70TwoStage; - using TiledMma = TiledMMA< - MMA_Atom>, - Layout>>; - - // A - static constexpr int kAlignmentA = 1; - using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - // B - static constexpr int kAlignmentB = 1; - using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - - -// -// DP4A - int8 Proof-of-concept -// - -// SIMT Two Stage TN - idp4a -template < - class ArchTag, - class ElementC, class LayoutC> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, ArchTag, - int8_t, cutlass::layout::RowMajor, - int8_t, cutlass::layout::ColumnMajor, - ElementC, LayoutC, - int32_t> -{ - using TileShape = Shape<_128, _128, _32>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm70TwoStage; - // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts - using TiledMma = TiledMMA< - MMA_Atom, - Layout>>; // Tile of atoms (threads) - - // A (M,K) K-major - using ElementA = int8_t; - // 40% from regular M and N major layout - // using SmemLayoutAtomA = Layout, - // Stride< _1,_128>>; - // 80% from interleaved layouts - using SmemLayoutAtomA = Layout>, - Stride< _4, Stride<_1,_512>>>; - - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 4; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); - - // B (N,K) K-major - using ElementB = int8_t; - // 40% from regular M and N major layout - // using SmemLayoutAtomB = Layout, - // Stride< _1,_128>>; - // 80% from interleaved layouts - using SmemLayoutAtomB = Layout>, - Stride< _4, Stride<_1,_512>>>; - - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 4; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Two Stage NN - idp4a -template < - class ArchTag, - class ElementC, class LayoutC> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, ArchTag, - int8_t, cutlass::layout::ColumnMajor, - int8_t, cutlass::layout::ColumnMajor, - ElementC, LayoutC, - int32_t> -{ - using TileShape = Shape<_128, _128, _32>; - static constexpr int ThreadCount = 256; - - using DispatchPolicy = MainloopSm70TwoStage; - - using TiledMma = TiledMMA< - MMA_Atom, - Layout>>; - - // A (M,K) M-major - using ElementA = int8_t; - using SmemLayoutAtomA = Layout>, - Stride< _4, Stride<_1,_512>>>; - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 1; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout, - Stride< _1,_32>>{}, - Layout>{})); - - // B (N,K) K-major - using ElementB = int8_t; - using SmemLayoutAtomB = Layout>, - Stride< _4, Stride<_1,_512>>>; - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 4; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Two Stage NT - idp4a -template < - class ArchTag, - class ElementC, class LayoutC> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, ArchTag, - int8_t, cutlass::layout::ColumnMajor, - int8_t, cutlass::layout::RowMajor, - ElementC, LayoutC, - int32_t> -{ - using TileShape = Shape<_128, _128, _32>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm70TwoStage; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>>; - - // A (M,K) M-major - using ElementA = int8_t; - using SmemLayoutAtomA = Layout>, - Stride< _4, Stride<_1,_512>>>; - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 1; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout, - Stride< _1,_32>>{}, - Layout>{})); - - // B (N,K) N-major - using ElementB = int8_t; - using SmemLayoutAtomB = Layout>, - Stride< _4, Stride<_1,_512>>>; - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 1; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout, - Stride< _1,_32>>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Two Stage TT - idp4a -template < - class ArchTag, - class ElementC, class LayoutC> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, ArchTag, - int8_t, cutlass::layout::RowMajor, - int8_t, cutlass::layout::RowMajor, - ElementC, LayoutC, - int32_t> -{ - using TileShape = Shape<_128, _128, _32>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm70TwoStage; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>>; - - // A (M,K) K-major - using ElementA = int8_t; - using SmemLayoutAtomA = Layout>, - Stride< _4, Stride<_1,_512>>>; - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 4; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); - - // B (N,K) N-major - using ElementB = int8_t; - using SmemLayoutAtomB = Layout>, - Stride< _4, Stride<_1,_512>>>; - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 1; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout, - Stride< _1,_32>>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// -/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Multi Stage NT -template < - class ElementA, - class ElementB, - class ElementC, class LayoutC, - class ElementAccumulator> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, arch::Sm80, - ElementA, cutlass::layout::ColumnMajor, - ElementB, cutlass::layout::RowMajor, - ElementC, LayoutC, - ElementAccumulator> -{ - using TileShape = Shape<_128, _128, _16>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom>, - Layout>, // 16x16x1 thread group - Tile,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization - Layout,Stride<_2,_1>>,Underscore>>; - - // A (M,K) M-major - using SmemLayoutAtomA = Layout>; - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 2; - using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout>{}, - Layout>{})); - - // B (N,K) N-major - using SmemLayoutAtomB = Layout>; - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 2; - using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Multi Stage TN -template < - class ElementA, - class ElementB, - class ElementC, class LayoutC, - class ElementAccumulator> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, arch::Sm80, - ElementA, cutlass::layout::RowMajor, - ElementB, cutlass::layout::ColumnMajor, - ElementC, LayoutC, - ElementAccumulator> -{ - using TileShape = Shape<_128, _128, _16>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom>, - Layout>>; - - // A (M,K) K-major - using SmemLayoutAtomA = Layout, - Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 1; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout, - Stride<_16, _1>>{})); - - // B (N,K) K-major - using SmemLayoutAtomB = Layout, - Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 1; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout, - Stride<_16, _1>>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Multi Stage NN -template < - class ElementA, - class ElementB, - class ElementC, class LayoutC, - class ElementAccumulator> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, arch::Sm80, - ElementA, cutlass::layout::ColumnMajor, - ElementB, cutlass::layout::ColumnMajor, - ElementC, LayoutC, - ElementAccumulator> -{ - using TileShape = Shape<_128, _128, _16>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom>, - Layout>, // 16x16x1 thread group - Tile,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization - - // A (M,K) M-major - using SmemLayoutAtomA = Layout>; - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 2; - using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout>{}, - Layout>{})); - - // B (N,K) K-major - using SmemLayoutAtomB = Layout, - Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 1; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout, - Stride<_16, _1>>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// SIMT Multi Stage TT -template < - class ElementA, - class ElementB, - class ElementC, class LayoutC, - class ElementAccumulator> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassSimt, arch::Sm80, - ElementA, cutlass::layout::RowMajor, - ElementB, cutlass::layout::RowMajor, - ElementC, LayoutC, - ElementAccumulator> -{ - using TileShape = Shape<_128, _128, _16>; - static constexpr int ThreadCount = 256; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom>, - Layout>, // 16x16x1 thread group - Tile,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization - - // A (M,K) K-major - using SmemLayoutAtomA = Layout, - Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 1; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, ElementA>{}, - Layout, - Stride<_16, _1>>{})); - - // B (N,K) N-major - using SmemLayoutAtomB = Layout>; - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 2; - using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, ElementB>{}, - Layout>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - ElementA, TagToStrideA_t, - ElementB, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - ElementC, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// Ampere fp64 MMA TN (K-Major A and K-Major B) -template <> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - double, cutlass::layout::RowMajor, - double, cutlass::layout::ColumnMajor, - double, cutlass::layout::ColumnMajor, - double> -{ - using TileShape = Shape<_128, _64, _16>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, // Atom - Layout>, // Atom layout - Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization - Layout,Stride<_2,_1>>, - Underscore>>; - - // A (M,K) K-Major - using SmemLayoutAtomA = decltype( - composition(Swizzle<2,0,4>{}, - Layout, - Stride<_1, _4>>{})); // M, K - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 1; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride<_16, _1>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 1x1 doubles - - // B (N,K) K-Major - using SmemLayoutAtomB = decltype( - composition(Swizzle<2,0,4>{}, - Layout, - Stride<_1, _4>>{})); // N, K - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 1; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride<_16, _1>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 1x1 doubles - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - double, TagToStrideA_t, - double, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - double, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; - -/* - using EpilogueOutputOp = epilogue::collective::Epilogue< - epilogue::thread::LinearCombination, - Layout, - Stride< _1,_64>>, // SMEM layout - Copy_Atom,double>, // R2S with tiled_mma layout - decltype(make_tiled_copy(Copy_Atom,double>{},// S2R - Layout, - Stride< _1,_16>>{}, // Thread layout - Layout>{})), // Value layout - Copy_Atom,double> // R2G with S2R_dst layout - >; -*/ -}; - -/////////////////////////////////////////////////////////////////////////////// - -// Ampere fp64 MMA NN (M-Major A and K-Major B) -template <> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - double, cutlass::layout::ColumnMajor, - double, cutlass::layout::ColumnMajor, - double, cutlass::layout::ColumnMajor, - double> -{ - using TileShape = Shape<_128, _64, _16>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, // Atom - Layout>, // Atom layout - Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization - Layout,Stride<_2,_1>>, - Underscore>>; - - // A (M,K) M-Major - using SmemLayoutAtomA = decltype( - composition(Swizzle<2,2,2>{}, - Layout, - Stride< _1,_16>>{})); // M, K - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 2; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride< _1,_16>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 2x1 doubles - - // B (N,K) K-Major - using SmemLayoutAtomB = decltype( - composition(Swizzle<2,0,4>{}, - Layout, - Stride<_1, _4>>{}));// N, K - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 1; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride<_16, _1>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 1x1 doubles - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - double, TagToStrideA_t, - double, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - double, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// Ampere fp64 MMA NT (M-Major A and N-Major B) -template <> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - double, cutlass::layout::ColumnMajor, - double, cutlass::layout::RowMajor, - double, cutlass::layout::ColumnMajor, - double> -{ - using TileShape = Shape<_128, _64, _16>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, // Atom - Layout>, // Atom layout - Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization - Layout,Stride<_2,_1>>, - Underscore>>; - - // A (M,K) M-Major - using SmemLayoutAtomA = decltype( - composition(Swizzle<2,2,2>{}, - Layout, - Stride< _1,_16>>{})); // M, K - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 2; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride< _1,_16>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 2x1 doubles - - // B (N,K) N-Major - using SmemLayoutAtomB = decltype( - composition(Swizzle<2,2,2>{}, - Layout, - Stride< _1,_16>>{})); // N, K - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 2; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride< _1,_16>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 2x1 doubles - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - double, TagToStrideA_t, - double, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - double, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// Ampere fp64 MMA TT (K-Major A and N-Major B) -template <> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm80, - double, cutlass::layout::RowMajor, - double, cutlass::layout::RowMajor, - double, cutlass::layout::ColumnMajor, - double> -{ - using TileShape = Shape<_128, _64, _16>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, // Atom - Layout>, // Atom layout - Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization - Layout,Stride<_2,_1>>, - Underscore>>; - - // A (M,K) K-Major - using SmemLayoutAtomA = decltype( - composition(Swizzle<2,0,4>{}, - Layout, - Stride<_1, _4>>{})); // M, K - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 1; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride<_16, _1>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 1x1 doubles - - // B (N,K) N-Major - using SmemLayoutAtomB = decltype( - composition(Swizzle<2,2,2>{}, - Layout, - Stride< _1,_16>>{})); // N, K - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 2; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, double>{}, // CopyAtom - Layout, - Stride< _1,_16>>{}, // ThrLayout for CopyAtom - Layout>{})); // Value layout: 2x1 doubles - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - double, TagToStrideA_t, - double, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - double, - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination, - cutlass::gemm::EpilogueDefault>; -}; - -/////////////////////////////////////////////////////////////////////////////// - -// Hopper fp64 MMA TN -template <> -struct DefaultGemmConfigurationToCutlass3Types< - arch::OpClassTensorOp, arch::Sm90, - double, cutlass::layout::RowMajor, - double, cutlass::layout::ColumnMajor, - double, cutlass::layout::ColumnMajor, - double> -{ - using TileShape = Shape<_128, _64, _16>; - static constexpr int ThreadCount = 128; - using DispatchPolicy = MainloopSm80CpAsync<3>; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>>; - - // A (M,K) K-major - using SmemLayoutAtomA = decltype( - make_ordered_layout(Shape<_128,_16>{}, - Step < _2, _1>{})); // M, K - using SmemCopyAtomA = Copy_Atom; - static constexpr int kAlignmentA = 2; - using GmemTiledCopyA = decltype( - make_tiled_copy(Copy_Atom, double>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); - - // B (N,K) K-major - using SmemLayoutAtomB = decltype( - make_ordered_layout(Shape<_64,_16>{}, - Step < _2, _1>{})); // N, K - using SmemCopyAtomB = Copy_Atom; - static constexpr int kAlignmentB = 2; - using GmemTiledCopyB = decltype( - make_tiled_copy(Copy_Atom, double>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); - - // Mainloop - using CollectiveMainloop = collective::CollectiveMma< - DispatchPolicy, TileShape, - double, TagToStrideA_t, - double, TagToStrideB_t, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - // Epilogue - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - TileShape, Shape<_1,_1,_1>, - cutlass::epilogue::collective::EpilogueTileAuto, - double, double, - double, cutlass::layout::ColumnMajor, 1, - double, cutlass::layout::ColumnMajor, 1, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp deleted file mode 100644 index 89755dd7d3162b114a537e58c6aa33cac80078f9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp +++ /dev/null @@ -1,3993 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include // std::lcm - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gett.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/fusion/operations.hpp" -#include "cutlass/complex.h" -#include "cutlass/transform/device/transform_universal_adapter.hpp" -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" -#include "cutlass/detail/collective.hpp" - -#include "testbed_utils.h" - -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/gemm/gemm.h" - -#include "cute/int_tuple.hpp" -#include "cute/layout.hpp" -#include "cute/numeric/int.hpp" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -enum class ScalarLoc { - ON_HOST = 0, - ON_DEVICE = 1 -}; - -enum class VectorScale { - DISABLED = 0, - ENABLED = 1 -}; - -enum class CheckEquality { - EXACT = 0, - RELATIVE = 1 -}; - -namespace detail { - -inline constexpr auto decomp_mode_to_string = - [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode mode) -> std::string { - using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - if (mode == Mode::Heuristic) { - return "Heuristic"; - } - else if (mode == Mode::DataParallel) { - return "DataParallel"; - } - else if (mode == Mode::SplitK) { - return "SplitK"; - } - else if (mode == Mode::StreamK) { - return "StreamK"; - } - else { - return "Unknown"; - } - }; - -inline constexpr auto raster_order_to_string = - [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions mode) -> std::string { - using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; - if (mode == Mode::Heuristic) { - return "Heuristic"; - } - else if (mode == Mode::AlongM) { - return "AlongM"; - } - else if (mode == Mode::AlongN) { - return "AlongN"; - } - else { - return "Unknown"; - } - }; - -// Helper classes that take default data type when -// the Gemm::EpilogueOutputOp does not have ElementCompute -// and ElementScalar. -// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) -template -struct ElementComputeType { - using Type = Default; -}; - -template -struct ElementComputeType>> { - using Type = typename Gemm::EpilogueOutputOp::ElementCompute; -}; - -template -struct ElementScalarType { - using Type = Default; -}; - -template -struct ElementScalarType>> { - using Type = typename Gemm::EpilogueOutputOp::ElementScalar; -}; - - -template -struct IsF8F6F4Kernel { - static constexpr bool value = false; -}; - -template -struct IsF8F6F4Kernel> { - static constexpr bool value = true; -}; - - -template -struct IsSfdEpi : cute::false_type {}; - -template -struct IsSfdEpi> : cute::true_type {}; - -// The maximum swizzle size to use -// -// This class, like Splits above makes it harder to confuse -// the order of arguments of the various run(...) functions in this file. -class MaxSwizzleSize { -public: - MaxSwizzleSize() = default; - - template && - !cute::is_same_v)) > - explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} - explicit operator int() const { return max_swizzle_size_; } -private: - int max_swizzle_size_ = 1; -}; - -template -auto make_iterator(T* ptr) { - return cute::recast_ptr(ptr); -} - -template -struct IsDefaultEpilogue { - static constexpr bool value = false; -}; - -template -struct IsDefaultEpilogue> { - static constexpr bool value = true; -}; - -template -struct IsDefaultEpilogue> { - static constexpr bool value = true; -}; - -template -struct IsLegacyEpiloguePolicy { - static constexpr bool value = false; -}; - -template -struct IsLegacyEpiloguePolicy> { - using EpiloguePolicy = typename Epilogue::DispatchPolicy; - static constexpr bool value = cute::is_same_v< - EpiloguePolicy, - cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< - EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize>>; -}; - -// The number of splits to test. -// -// This class makes it harder to confuse the order of arguments -// of the various run(...) functions in this file. The constructor -// is explicit, so one can't just type 42 (or false, which the -// compiler unhelpfully turns into 0); one has to type Splits(42). -// Splits() picks the default number of splits, 1. -// -// The conversion-to-int operator (operator int()) MUST be explicit! -// Conversion to int MUST require static_cast. -// Otherwise, that defeats a key purpose of this class, -// which is to catch common errors of confusing the order -// of function arguments. -class Splits { -public: - Splits() = default; - - template && - !cute::is_same_v)) > - explicit Splits(IntegralNotBool splits) : splits_(splits) {} - explicit operator int() const { return splits_; } -private: - int splits_ = 1; -}; - -// The number of iterations to test. -// -// This class, like Splits above makes it harder to confuse -// the order of arguments of the various run(...) functions in this file. -// Iterations() picks the default number of iterations, 20. -class Iterations { -public: - Iterations() = default; - - template && - !cute::is_same_v)) > - explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} - explicit operator int() const { return iterations_; } -private: - int iterations_ = 20; -}; - -template -bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } - - else if (bits_input <= 6) { - scope_max = 2; - scope_min = -2; - } - - else if (bits_input <= 8) { - - if constexpr ( - cute::is_same_v){ - scope_max = 4; - scope_min = 1; - } - else { - - scope_max = 1; - scope_min = -1; - - } - - } - else{ - scope_max = 4; - scope_min = -4; - } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - - else if (dist_kind == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(view); - } - - else if (dist_kind == cutlass::Distribution::Gaussian) { - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - - else if (dist_kind == cutlass::Distribution::AllOnes) { - cutlass::reference::host::TensorFill(view, Element(1)); - } - - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; -} - -// Looks at Cute Stride to check Row / Column Major -template -static constexpr bool is_row_or_col_major(){ - int stride_0 = int(cute::size<0>(Stride{})); - int stride_1 = int(cute::size<1>(Stride{})); - int depth = cute::depth(Stride{}); - return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); -} - - -// -// Default MMA input Operands : A , B -// -template< - class ScheduleType_, - class Gemm, - class ElementA_ = typename Gemm::GemmKernel::ElementA, - class ElementB_ = typename Gemm::GemmKernel::ElementB, - class Enable = void> -struct HostCollectiveMainloop { - // Kernel data types - using ElementA = ElementA_; - using StrideA = typename Gemm::GemmKernel::StrideA; - using ElementB = ElementB_; - using StrideB = typename Gemm::GemmKernel::StrideB; - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - - using Arguments = typename Gemm::GemmKernel::MainloopArguments; - - cutlass::ComplexTransform TransformA = Gemm::kTransformA; - cutlass::ComplexTransform TransformB = Gemm::kTransformB; - - StrideA stride_a; - StrideB stride_b; - - typename LayoutTagA::Stride stride_factor_A; - typename LayoutTagB::Stride stride_factor_B; - - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed, - typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), - typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() - ): - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - init_A(init_A_), init_B(init_B_), seed(seed_), - check_relative_equality(check_relative_equality_) { } - - template - bool initialize(ProblemShapeType problem_size) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop (generic)::initialize(problem_shape)"); -#endif - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - - stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto a_coord = cutlass::make_Coord(M * L, K); - // Cutlass has Row/Col major refers to MxK times KxN matrix product, - // so the HostTensorB should be treated as KxN in "coord"'s view - auto b_coord = cutlass::make_Coord(K, N * L); - - try { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.resize"); -#endif - tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.resize"); -#endif - tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an unknown exception"); - throw; - } - - try { - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); - } - catch (cutlass::cuda_exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw cutlass::cuda_exception: " << e); - throw; - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked_initialize_tensor threw an unknown exception"); - throw; - } - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = ElementA(1); - tensor_B.host_view().at({0, 0}) = ElementB(1); - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Check last error before sync_device()"); - cudaError_t error = cudaGetLastError(); - const auto error_str = cudaGetErrorString(error); - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: cudaGetLastError() is " << error_str); - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.host_data()=" << tensor_A.host_data() << ", tensor_A.device_data()=" << tensor_A.device_data()); - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.host_data()=" << tensor_B.host_data() << ", tensor_B.device_data()=" << tensor_B.device_data()); - } -#endif - try { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.sync_device"); -#endif - tensor_A.sync_device(); -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.sync_device"); -#endif - tensor_B.sync_device(); - } - catch (cutlass::cuda_exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw cutlass::cuda_exception: " << e); - throw; - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an unknown exception"); - throw; - } - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Reached end"); -#endif - return true; - } - - Arguments to_args() { - - - // Runtime datatype selection - if constexpr (not cute::is_same_v) { - using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; - return { - reinterpret_cast(tensor_A.device_data()), stride_a, - reinterpret_cast(tensor_B.device_data()), stride_b - }; - } - else { - - Arguments arguments = - { - tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b - }; - return arguments; - } - } - - auto to_host_args(ProblemShapeType problem_size) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - auto A = make_tensor(make_iterator(tensor_A.host_data()), - make_layout(make_shape(M, K, L), stride_a)); - auto B = make_tensor(make_iterator(tensor_B.host_data()), - make_layout(make_shape(N, K, L), stride_b)); - - - auto dummy_SFA = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, K, L), stride_a)); - auto dummy_SFB = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(N, K, L), stride_b)); - - cutlass::reference::host::GettMainloopParams mainloop_params{}; - - mainloop_params.A = A; - mainloop_params.B = B; - mainloop_params.transform_A = TransformA; - mainloop_params.transform_B = TransformB; - - return mainloop_params; - } - - void print_tensors(std::ofstream& file) { - file << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view(); - } - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - - // Factors used for calculating relative equality. CUTLASS's relative-equality - // checks in include/cutlass/relatively_equal.h are inspired by - // https://floating-point-gui.de/errors/comparison/. This reference suggests using - // the minimum normal value of a given type as the nonzero_floor. - Element epsilon(static_cast(0.1f)); - Element nonzero_floor(std::numeric_limits::min()); - - if constexpr (!cutlass::is_complex::value) { - if (check_relative_equality == CheckEquality::RELATIVE) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - bool compare_reference( - cute::Shape problem_shape_MNKL) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - - bool passed = true; - return passed; - } -}; - -// -// Sparse MMA host implementation -// -template< - class Gemm, - class ElementA_, - class ElementB_> -struct HostCollectiveMainloopSparse -{ - - // Kernel data types - using ElementA = ElementA_; - // CuTe layout A for the kernel's sparse tensorA. - using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; - using ElementB = ElementB_; - using StrideB = typename Gemm::GemmKernel::StrideB; - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - - using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; - // CuTe layout E for the kernel's metadata tensor. - using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; - - // The following typenames are for the reference host tensors. They are non-sparse tensors. - using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); - using StrideA = cutlass::gemm::TagToStrideA_t; - // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. - using StrideE = StrideA; - - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; - - using ArchTag = typename Gemm::ArchTag; - - using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< - cute::Shape, - ElementA, - LayoutTagA, - SparseConfig>; - - using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< - cute::Shape, - ElementA, - LayoutTagA, - SparseConfig, - ArchTag>; - - using Compressor = cutlass::transform::device::TransformUniversalAdapter; - - using Arguments = typename Gemm::GemmKernel::MainloopArguments; - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - - StrideA stride_a; - StrideA stride_a_compressed; - StrideB stride_b; - StrideE stride_e; - - LayoutA layout_a; - LayoutE layout_e; - - typename LayoutTagA::Stride stride_factor_A; - typename LayoutTagB::Stride stride_factor_B; - typename LayoutTagE::Stride stride_factor_E; - - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_A_Comp; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_E; - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - static constexpr int MaxSmCount = 16; - - HostCollectiveMainloopSparse( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed, - typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), - typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), - typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() - ): - check_relative_equality(check_relative_equality_), - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - stride_factor_E(stride_factor_E_), - init_A(init_A_), init_B(init_B_), seed(seed_) { } - - template - bool initialize(ProblemShapeType problem_size) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloopSparse::initialize"); -#endif - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - - stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - - CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); - - // TensorE - // In unit of ElementE (uint8_t), after alignment requirement - // M-dim: TensorEAtom_M alignment - // K-dim: TensorEAtom_K alignment - int KAlignedE = compressor_utility.get_metadata_k_physical(); - int MAlignedE = compressor_utility.get_metadata_m_physical(); - - // TensorA Compressed - // In unit of ElementARaw, after alignment requirement - // M-dim: TMA alignment - // K-dim: TMA alignment - int KAlignedAC = compressor_utility.get_tensorA_k_physical(); - int MAlignedAC = compressor_utility.get_tensorA_m_physical(); - - stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); - stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); - - auto a_coord = cutlass::make_Coord(M * L, K); - auto b_coord = cutlass::make_Coord(K, N * L); - auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); - auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); - - tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); - tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); - tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = ElementA(1); - tensor_B.host_view().at({0, 0}) = ElementB(1); - - compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_E.sync_device(); - tensor_A_Comp.sync_device(); - - cutlass::Status status {cutlass::Status::kSuccess }; - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Compressor::Arguments arguments{ - {M, N, K, L}, - {tensor_A.device_data(), - stride_a, - tensor_A_Comp.device_data(), - tensor_E.device_data()}, - {hw_info} - }; - - Compressor compressor_op; - size_t workspace_size = Compressor::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - status = compressor_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - return false; - } - - status = compressor_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - return false; - } - - status = compressor_op.run(); - - auto result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); - layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); - - tensor_E.sync_host(); - tensor_A_Comp.sync_host(); - - return true; - } - - Arguments to_args() { - using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; - return { - reinterpret_cast(tensor_A_Comp.device_data()), layout_a, - reinterpret_cast(tensor_B.device_data()), stride_b, - tensor_E.device_data(), layout_e - }; - } - - auto to_host_args(ProblemShapeType problem_size) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - auto A = make_tensor(make_iterator(tensor_A.host_data()), - make_layout(make_shape(M, K, L), stride_a)); - auto B = make_tensor(make_iterator(tensor_B.host_data()), - make_layout(make_shape(N, K, L), stride_b)); - - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; - return mainloop_params; - } - - void print_tensors(std::ofstream& file) { - file << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view(); - } - - bool compare_reference( - cute::Shape problem_shape_MNKL) { - auto [M, N, K, L] = problem_shape_MNKL; - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - return true; - } -}; - -template< - class ScheduleType_, - class Gemm, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - typename Gemm::CollectiveMainloop::DispatchPolicy>>> - : HostCollectiveMainloopSparse -{ - using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; -}; - -// -// Sparse MMA input Operands : A_compressed, B, metadata -// -// Structured Sparse Gemm Input Operands - -template< - class Gemm, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - typename ElementA_, - typename ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> - : HostCollectiveMainloopSparse -{ - using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; -}; - -// -// Sparse Gemm Input Operands : A , B, E -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_ >; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), - typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, - stride_factor_B_, - stride_factor_E_) {} -}; - -// -// Sparse Gemm Input Operands : A , B, E -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_ >; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), - typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, - stride_factor_B_, - stride_factor_E_) {} -}; - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - // Kernel data types - using ElementA = ElementA_; - using StrideA = typename Gemm::GemmKernel::StrideA; - using ElementB = ElementB_; - using StrideB = typename Gemm::GemmKernel::StrideB; - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - - static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; - - using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; - using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; - using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; - using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; - - using Arguments = typename Gemm::GemmKernel::MainloopArguments; - - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - - StrideA stride_a; - StrideB stride_b; - - LayoutSFA layout_sfa; - LayoutSFB layout_sfb; - - typename LayoutTagA::Stride stride_factor_A; - typename LayoutTagB::Stride stride_factor_B; - - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_SFA; - cutlass::HostTensor tensor_SFB; - - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed, - typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), - typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() - ): - check_relative_equality(check_relative_equality_), - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - init_A(init_A_), init_B(init_B_), seed(seed_) { } - - template - bool initialize(ProblemShapeType problem_size) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelTmaWarpSpecializedBlockScaledSm100)::initialize"); -#endif - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - - stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto a_coord = cutlass::make_Coord(M * L, K); - // Cutlass has Row/Col major refers to MxK times KxN matrix product, - // so the HostTensorB should be treated as KxN in "coord"'s view - auto b_coord = cutlass::make_Coord(K, N * L); - - tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); - tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = ElementA(1); - tensor_B.host_view().at({0, 0}) = ElementB(1); - - tensor_A.sync_device(); - tensor_B.sync_device(); - - using namespace cute; - auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); - auto m_blks = cutlass::ceil_div(M, Blk_MN{}); - auto n_blks = cutlass::ceil_div(N, Blk_MN{}); - layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); - layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); - auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); - - tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); - tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); - - EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); - EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_SFA.host_view().at({0, 0}) = ElementSF(1); - tensor_SFB.host_view().at({0, 0}) = ElementSF(1); - - tensor_SFA.sync_device(); - tensor_SFB.sync_device(); - - return true; - } - - Arguments to_args() { - using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; - return { - reinterpret_cast(tensor_A.device_data()), stride_a, - reinterpret_cast(tensor_B.device_data()), stride_b, - tensor_SFA.device_data(), layout_sfa, - tensor_SFB.device_data(), layout_sfb - }; - } - - auto to_host_args(ProblemShapeType problem_size) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - auto A = make_tensor(make_iterator(tensor_A.host_data()), - make_layout(make_shape(M, K, L), stride_a)); - auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); - - auto B = make_tensor(make_iterator(tensor_B.host_data()), - make_layout(make_shape(N, K, L), stride_b)); - auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); - - cutlass::reference::host::GettMainloopParams - mainloop_params{A, SfA, B, SfB}; - return mainloop_params; - } - - void print_tensors(std::ofstream& file) { - file << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nSFA =\n" << tensor_SFA.host_view() - << "\nSFB =\n" << tensor_SFB.host_view(); - } - - bool compare_reference( - cute::Shape problem_shape_MNKL) { - auto [M, N, K, L] = problem_shape_MNKL; - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); - return true; - } -}; - - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} -}; - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} -}; - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} -}; - -// -// Block Scaled Structured Sparse Gemm Input Operands : A_compressed, B, metadata, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - typename ElementA_, - typename ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - // Kernel data types - using ElementA = ElementA_; - // CuTe layout A for the kernel's sparse tensorA. - using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; - using ElementB = ElementB_; - using StrideB = typename Gemm::GemmKernel::StrideB; - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - - using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; - // CuTe layout E for the kernel's metadata tensor. - using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; - - // The following typenames are for the reference host tensors. They are non-sparse tensors. - using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); - using StrideA = cutlass::gemm::TagToStrideA_t; - // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. - using StrideE = StrideA; - - static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - - using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; - - using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; - using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; - using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; - using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; - - using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< - cute::Shape, - ElementA, - LayoutTagA, - SparseConfig>; - using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< - cute::Shape, - ElementA, - LayoutTagA, - SparseConfig, - cutlass::arch::Sm100>; - - using Compressor = cutlass::transform::device::TransformUniversalAdapter; - - using Arguments = typename Gemm::GemmKernel::MainloopArguments; - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - - StrideA stride_a; - StrideA stride_a_compressed; - StrideB stride_b; - StrideE stride_e; - - LayoutA layout_a; - LayoutE layout_e; - LayoutSFA layout_sfa; - LayoutSFB layout_sfb; - - typename LayoutTagA::Stride stride_factor_A; - typename LayoutTagB::Stride stride_factor_B; - typename LayoutTagE::Stride stride_factor_E; - - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_A_Comp; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_E; - cutlass::HostTensor tensor_SFA; - cutlass::HostTensor tensor_SFB; - - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed, - typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), - typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), - typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() - ): - check_relative_equality(check_relative_equality_), - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - stride_factor_E(stride_factor_E_), - init_A(init_A_), init_B(init_B_), seed(seed_) { } - - template - bool initialize(ProblemShapeType problem_size) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelSparseTmaWarpSpecializedBlockScaledSm100)::initialize"); -#endif - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - - stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - - CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); - - // TensorE - // In unit of ElementE (uint8_t), after alignment requirement - // M-dim: TensorEAtom_M alignment - // K-dim: TensorEAtom_K alignment - int KAlignedE = compressor_utility.get_metadata_k_physical(); - int MAlignedE = compressor_utility.get_metadata_m_physical(); - - // TensorA Compressed - // In unit of ElementARaw, after alignment requirement - // M-dim: TMA alignment - // K-dim: TMA alignment - int KAlignedAC = compressor_utility.get_tensorA_k_physical(); - int MAlignedAC = compressor_utility.get_tensorA_m_physical(); - - stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); - stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); - - auto a_coord = cutlass::make_Coord(M * L, K); - auto b_coord = cutlass::make_Coord(K, N * L); - auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); - auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); - - tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); - tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); - tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = ElementA(1); - tensor_B.host_view().at({0, 0}) = ElementB(1); - - compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_E.sync_device(); - tensor_A_Comp.sync_device(); - - cutlass::Status status {cutlass::Status::kSuccess }; - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Compressor::Arguments arguments{ - {M, N, K, L}, - {tensor_A.device_data(), - stride_a, - tensor_A_Comp.device_data(), - tensor_E.device_data()}, - {hw_info} - }; - - Compressor compressor_op; - size_t workspace_size = Compressor::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - status = compressor_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - return false; - } - - status = compressor_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - return false; - } - - status = compressor_op.run(); - - auto result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); - layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); - - tensor_E.sync_host(); - tensor_A_Comp.sync_host(); - - using namespace cute; - auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); - auto m_blks = cutlass::ceil_div(M, Blk_MN{}); - auto n_blks = cutlass::ceil_div(N, Blk_MN{}); - layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); - layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); - auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); - - tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); - tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); - - EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); - EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_SFA.host_view().at({0, 0}) = ElementSF(1); - tensor_SFB.host_view().at({0, 0}) = ElementSF(1); - - tensor_SFA.sync_device(); - tensor_SFB.sync_device(); - - return true; - } - - Arguments to_args() { - using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; - return { - reinterpret_cast(tensor_A_Comp.device_data()), layout_a, - reinterpret_cast(tensor_B.device_data()), stride_b, - tensor_E.device_data(), layout_e, - tensor_SFA.device_data(), layout_sfa, - tensor_SFB.device_data(), layout_sfb - }; - } - - auto to_host_args(ProblemShapeType problem_size) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - auto A = make_tensor(make_iterator(tensor_A.host_data()), - make_layout(make_shape(M, K, L), stride_a)); - auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); - - auto B = make_tensor(make_iterator(tensor_B.host_data()), - make_layout(make_shape(N, K, L), stride_b)); - auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); - - // return {A, SfA, B, SfB}; - cutlass::reference::host::GettMainloopParams - mainloop_params{A, SfA, B, SfB}; - return mainloop_params; - } - - void print_tensors(std::ofstream& file) { - file << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nSFA =\n" << tensor_SFA.host_view() - << "\nSFB =\n" << tensor_SFB.host_view(); - } - - bool compare_reference( - cute::Shape problem_shape_MNKL) { - auto [M, N, K, L] = problem_shape_MNKL; - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); - return true; - } -}; - -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), - typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, - stride_factor_B_, - stride_factor_E_) {} -}; - -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), - typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, - stride_factor_B_, - stride_factor_E_) {} -}; - -template -struct HostCollectiveDefaultEpilogue { - // fusion types are potentially void if the fusion is not supported - // helper so we don't try to construct HostTensor with void type - template - using non_void_t = cute::conditional_t, U, T>; - - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using kernel = typename Gemm::GemmKernel; - using Epilogue = typename kernel::CollectiveEpilogue; - - using ElementD = typename kernel::ElementD; - using StrideD = typename kernel::StrideD; - using ElementC = non_void_t; - using StrideC = typename kernel::StrideC; - - using FusionOp = typename Gemm::EpilogueOutputOp; - - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - static_assert(is_row_or_col_major(), - "ERROR : C Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : D Layout is neither Row / Column Major)"); - - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - using ElementAccumulator = typename kernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename kernel::ProblemShape; - using ElementCompute = typename ElementComputeType::Type; - using ElementScalar = typename ElementScalarType::Type; - - using Arguments = typename Gemm::GemmKernel::EpilogueArguments; - - /// Initialization - StrideC stride_c; - StrideD stride_d; - - typename LayoutTagC::Stride stride_factor_C; - typename LayoutTagD::Stride stride_factor_D; - - cutlass::HostTensor tensor_C; - // Inputs - ElementScalar alpha; - ElementScalar beta; - - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - // Are scalars copied to device memory before kernel launch - ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; - // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector - VectorScale vector_scale_mode = VectorScale::DISABLED; - - cutlass::Distribution::Kind init_C; - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - HostCollectiveDefaultEpilogue( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), - stride_factor_D(typename LayoutTagD::Stride()), - check_relative_equality(check_relative_equality_), - use_device_scalars(use_device_scalars_){ } - - bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize(problem_size, alpha, beta)"); -#endif - // Initialize Epilogue tensors - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto c_coord = cutlass::make_Coord(M * L, N); - try { - tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); - tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); - reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an unknown exception"); - throw; - } - { - const bool init_succeeded = initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); - if (not init_succeeded) { - CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: initialize_tensor returned false"); - } - EXPECT_TRUE(init_succeeded); - } - tensor_C.host_view().at({0, 0}) = ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - try { - tensor_C.sync_device(); - tensor_D.sync_device(); - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an unknown exception"); - throw; - } - - alpha = alpha_; - beta = beta_; - - return true; - } - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - - // Factors used for calculating relative equality. CUTLASS's relative-equality - // checks in include/cutlass/relatively_equal.h are inspired by - // https://floating-point-gui.de/errors/comparison/. This reference suggests using - // the minimum normal value of a given type as the nonzero_floor. - Element epsilon(static_cast(0.1f)); - Element nonzero_floor(std::numeric_limits::min()); - - if constexpr (!cutlass::is_complex::value) { - if (check_relative_equality == CheckEquality::RELATIVE) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - bool compare_reference( - cute::Shape problem_shape_MNKL, - ElementScalar alpha, - ElementScalar beta) { - auto [M, N, K, L] = problem_shape_MNKL; - - tensor_D.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - } - - if (reference_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - } - - bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); - if(!passed) { - std::cout<<"D is incorrect"<(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - auto N = cute::get<1>(problem_shape_MNKL); - auto K = cute::get<2>(problem_shape_MNKL); - auto L = cute::get<3>(problem_shape_MNKL); - auto coord_0 = cutlass::make_Coord(0); - auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - - cutlass::reference::host::GettEpilogueParams< - ElementScalar, - ElementScalar, - ElementAccumulator, - ElementCompute, - decltype(C), - decltype(D)> - epilogue_params{}; - - epilogue_params.C = C; - epilogue_params.D = D; - epilogue_params.alpha = alpha; - epilogue_params.beta = beta; - - return epilogue_params; - } -}; - -template -struct HostCollectiveEpilogue { - // fusion types are potentially void if the fusion is not supported - // helper so we don't try to construct HostTensor with void type - template - using non_void_t = cute::conditional_t, U, T>; - - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using kernel = typename Gemm::GemmKernel; - using Epilogue = typename kernel::CollectiveEpilogue; - static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); - - using ElementD = typename kernel::ElementD; - using StrideD = typename kernel::StrideD; - using ElementC = non_void_t; - using StrideC = typename kernel::StrideC; - - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - static_assert(is_row_or_col_major(), - "ERROR : C Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : D Layout is neither Row / Column Major)"); - - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - using ElementAccumulator = typename kernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename kernel::ProblemShape; - - // - // FusionOperation derived types/queries - // - static constexpr bool IsLegacy = detail::IsLegacyEpiloguePolicy::value; - - // FFMA2 SGEMM uses ThreadEpilogueOp for bias and relu support instead of FusionOp, so we compose LinCombPerRowBiasEltAct FusionOp by hand to test the functionality. - static constexpr bool IsFfma2Kernel = cute::is_same_v; - using FusionOp = cute::conditional_t, - typename Gemm::EpilogueOutputOp>; - static_assert(cute::is_base_of_v); - - - // Scale factor Generation related - using SfStrategy = cutlass::reference::host::SfStrategy; - static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; - static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; - static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; - static constexpr bool IsKMajorSFD = cute::is_same_v; - using ElementSFD = non_void_t; - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; - using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; - using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; - cutlass::HostTensor tensor_SFD; - cutlass::HostTensor reference_SFD; - - using ElementCompute = typename FusionOp::ElementCompute; - using ElementScalar = typename FusionOp::ElementScalar; - using ElementBias = non_void_t; - using ElementAux = non_void_t; - using ElementAmax = non_void_t; - using LayoutTagAux = non_void_t; - using ActivationFunctor = non_void_t>; - - static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; - static constexpr bool IsColBiasEnabled = FusionOp::IsPerColBiasSupported; - static_assert(not (IsColBiasEnabled && IsRowBiasEnabled)); - - static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; - static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; - static constexpr bool IsPerColScaleEnabled = FusionOp::IsPerColScaleSupported; - static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; - static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; - static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; - static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && - (cute::is_same_v || - cute::is_same_v); - static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && - (cute::is_same_v || - cute::is_same_v); - using Arguments = typename Gemm::GemmKernel::EpilogueArguments; - - /// Initialization - StrideC stride_c; - StrideD stride_d; - - typename LayoutTagC::Stride stride_factor_C; - typename LayoutTagD::Stride stride_factor_D; - - // Inputs - cutlass::HostTensor alpha; - cutlass::HostTensor beta; - cutlass::HostTensor scale_A; - cutlass::HostTensor scale_B; - cutlass::HostTensor scale_C; - cutlass::HostTensor scale_D; - cutlass::HostTensor scale_Aux; - cutlass::HostTensor bias; - cutlass::HostTensor tensor_C; - cutlass::HostTensor norm_constant; - - // Outputs - cutlass::HostTensor abs_max_Aux; - cutlass::HostTensor abs_max_D; - cutlass::HostTensor tensor_Aux; - cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // References - cutlass::HostTensor reference_dbias; - cutlass::HostTensor reference_Aux; - cutlass::HostTensor reference_abs_max_Aux; - cutlass::HostTensor reference_abs_max_D; - - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - // Are scalars copied to device memory before kernel launch - ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; - // If vector scale is supported and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector - VectorScale vector_scale_mode = VectorScale::DISABLED; - - // Random distribution with which to initialize the A/B/C/D/Aux scaling factors - cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; - // Random distribution with which to initialize the bias vector - cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; - cutlass::Distribution::Kind init_C; - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - HostCollectiveEpilogue( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): init_scale(init_scale_), init_bias(init_bias_), - init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), - stride_factor_D(typename LayoutTagD::Stride()), - check_relative_equality(check_relative_equality_), - use_device_scalars(use_device_scalars_){ } - - bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize(problem_size, alpha, beta)"); -#endif - // Initialize Epilogue tensors - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::size<0>(problem_shape_MNKL); - auto N = cute::size<1>(problem_shape_MNKL); - auto K = cute::size<2>(problem_shape_MNKL); - auto L = cute::size<3>(problem_shape_MNKL); - - stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto c_coord = cutlass::make_Coord(M * L, N); - try { - tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); - tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); - reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an unknown exception"); - throw; - } - - try { - const bool initialize_tensor_C_succeeded = - initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); - if (not initialize_tensor_C_succeeded) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor returned false"); - } - EXPECT_TRUE(initialize_tensor_C_succeeded); - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an unknown exception"); - throw; - } - - tensor_C.host_view().at({0, 0}) = ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - try { - tensor_C.sync_device(); - tensor_D.sync_device(); - } - catch (std::exception const& e) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an unknown exception"); - throw; - } - - auto scalar_coord = cutlass::make_Coord(1); - auto col_vector_coord = cutlass::make_Coord(M); - auto row_vector_coord = cutlass::make_Coord(N); - auto batch_vector_coord = cutlass::make_Coord(L); - if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { - // scalars - if (vector_scale_mode == VectorScale::DISABLED) { - // batched scalars - if (use_device_scalars == ScalarLoc::ON_DEVICE) { - alpha.resize(batch_vector_coord, true); - beta.resize(batch_vector_coord, true); - EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); - if (beta_ != ElementScalar(0)) { - EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); - } - else { - cutlass::reference::host::TensorFill(beta.host_view(), beta_); - } - } - // non-batched scalars - else { - alpha.resize(scalar_coord, false); - beta.resize(scalar_coord, false); - cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); - cutlass::reference::host::TensorFill(beta.host_view(), beta_); - } - } - // batched vectors - else { - auto batched_vector_coord = cutlass::make_Coord((IsPerRowScaleEnabled ? M : N) * L); - alpha.resize(batched_vector_coord, true); - beta.resize(batched_vector_coord, true); - EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); - if (beta_ != ElementScalar(0)) { - EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); - } - else { - cutlass::reference::host::TensorFill(beta.host_view(), beta_); - } - } - } - else { - if (use_device_scalars == ScalarLoc::ON_DEVICE) { - // Set alpha beta for different batches. - alpha.resize(batch_vector_coord, true); - beta.resize(batch_vector_coord, true); - cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); - for (int l = 0; l < L; ++l) { - beta.host_view().at(cutlass::make_Coord(l)) = beta_ + ElementScalar(l); - } - } - else { - alpha.resize(scalar_coord, false); - beta.resize(scalar_coord, false); - cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); - cutlass::reference::host::TensorFill(beta.host_view(), beta_); - } - } - alpha.sync_device(); - beta.sync_device(); - - if constexpr (IsScaleFactorEnabled) { - scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); - EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); - EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); - EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); - scale_A.sync_device(); - scale_B.sync_device(); - scale_C.sync_device(); - scale_D.sync_device(); - } - - if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { - bias.resize(IsRowBiasEnabled ? col_vector_coord : row_vector_coord); - EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); - bias.sync_device(); - } - - if constexpr (IsDeBiasEnabled) { - bias.resize(col_vector_coord); - reference_dbias.resize(col_vector_coord); - cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); - cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); - bias.sync_device(); - } - - if constexpr (IsAbsMaxEnabledD) { - abs_max_D.resize(scalar_coord); - // ensure in-place device reductions perform their own initialization - cutlass::reference::host::TensorFill(abs_max_D.host_view(), - CUTLASS_STL_NAMESPACE::numeric_limits::max()); - abs_max_D.sync_device(); - reference_abs_max_D.resize(scalar_coord); - cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); - } - - if constexpr (IsAuxInEnabled) { - auto aux_coord = cutlass::make_Coord(M * L, N); - auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); - tensor_Aux.resize(aux_coord, aux_layout); - EXPECT_TRUE(initialize_tensor(tensor_Aux.host_view(), init_C, seed + 2023)); - tensor_Aux.sync_device(); - stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); - } - - if constexpr (IsAuxOutEnabled) { - auto aux_coord = cutlass::make_Coord(M * L, N); - auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); - tensor_Aux.resize(aux_coord, aux_layout); - reference_Aux.resize(aux_coord, aux_layout, false); - tensor_Aux.sync_device(); - stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); - - if constexpr (IsScaleFactorEnabled) { - scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); - scale_Aux.sync_device(); - } - - if constexpr (IsAbsMaxEnabledAux) { - abs_max_Aux.resize(scalar_coord); - // ensure in-place device reductions perform their own initialization - cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), - CUTLASS_STL_NAMESPACE::numeric_limits::max()); - abs_max_Aux.sync_device(); - reference_abs_max_Aux.resize(scalar_coord); - cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); - } - } - - - if constexpr (IsBlockScaleSupported) { - auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); - auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); - auto sfd_coord = [&] () { - if constexpr (IsKMajorSFD) { - return cutlass::make_Coord(m_blks * Blk_MN{} * L, n_blks * Blk_SF{}); - } - else { - return cutlass::make_Coord(m_blks * Blk_SF{} * L, n_blks * Blk_MN{}); - } - }(); - tensor_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D)); - reference_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false); - tensor_SFD.sync_device(); - norm_constant.resize(scalar_coord, true); - EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); - norm_constant.sync_device(); - } - - - return true; - } - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - - // Factors used for calculating relative equality. CUTLASS's relative-equality - // checks in include/cutlass/relatively_equal.h are inspired by - // https://floating-point-gui.de/errors/comparison/. This reference suggests using - // the minimum normal value of a given type as the nonzero_floor. - Element epsilon(static_cast(0.1f)); - Element nonzero_floor(std::numeric_limits::min()); - - if constexpr (!cutlass::is_complex::value) { - if (check_relative_equality == CheckEquality::RELATIVE) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - bool compare_reference( - cute::Shape problem_shape_MNKL, - ElementScalar alpha, - ElementScalar beta) { - tensor_D.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - } - - if (reference_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - } - - bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); - if(!passed) { - #if 0 - auto [M, N, K, L] = problem_shape_MNKL; - auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { - printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); - } - } - } - } - #endif - std::cout<<"D is incorrect"<(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - Arguments arguments = - { - {}, - tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d - }; - - auto &fusion_args = arguments.thread; - if constexpr (IsLegacy) { - arguments.thread = { - alpha.at(coord_0), - beta.at(coord_0), - alpha.device_data(), - beta.device_data() - }; - arguments.ptr_Bias = bias.device_data(); - arguments.ptr_T = tensor_Aux.device_data(); - } - else { - fusion_args.alpha = alpha.at(coord_0); - fusion_args.alpha_ptr = alpha.device_data(); - // Only initializing beta/beta_ptr for non-void source - if constexpr (not cute::is_void_v) { - fusion_args.beta = beta.at(coord_0); - fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr - } - - if constexpr (IsPerRowScaleEnabled) { - int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; - int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); - fusion_args.dAlpha = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); - fusion_args.dBeta = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); - } - else if constexpr (IsPerColScaleEnabled) { - int32_t n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; - int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); - fusion_args.dAlpha = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); - fusion_args.dBeta = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); - } - else { - if constexpr (not IsFfma2Kernel) { - if (use_device_scalars == ScalarLoc::ON_DEVICE) { - if (L > 1) { - fusion_args.dAlpha = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); - fusion_args.dBeta = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); - } - } - } - } - - if constexpr (IsScaleFactorEnabled) { - fusion_args.scale_a = scale_A.at(coord_0); - fusion_args.scale_b = scale_B.at(coord_0); - fusion_args.scale_c = scale_C.at(coord_0); - fusion_args.scale_d = scale_D.at(coord_0); - fusion_args.scale_a_ptr = scale_A.device_data(); - fusion_args.scale_b_ptr = scale_B.device_data(); - fusion_args.scale_c_ptr = scale_C.device_data(); - fusion_args.scale_d_ptr = scale_D.device_data(); - } - - if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { - fusion_args.bias_ptr = bias.device_data(); - } - - if constexpr (IsDeBiasEnabled) { - fusion_args.dbias_ptr = bias.device_data(); - } - - // example of how to set kernel activation arguments - // see ActivationFunctor::Arguments in activation.h for definition - // if Arguments doesn't exist then fusion_args.activation is empty - auto init_activation_args = [] (auto activation, auto& args) { - using Activation = cute::remove_cvref_t; - if constexpr (cute::is_same_v>) { - args.lower_bound = 0; // Treat Clamp as ReLU - args.upper_bound = cutlass::platform::identity_for_minimum(); - } - if constexpr (cute::is_same_v>) { - args.scale = ElementCompute(1); - } - }; - - if constexpr (not cute::is_same_v>) { - init_activation_args(ActivationFunctor{}, fusion_args.activation); - } - if constexpr (IsAbsMaxEnabledD) { - fusion_args.amax_D_ptr = abs_max_D.device_data(); - } - - if constexpr (IsAuxInEnabled) { - fusion_args.aux_ptr = tensor_Aux.device_data(); - fusion_args.dAux = stride_Aux; - } - - if constexpr (IsAuxOutEnabled) { - fusion_args.aux_ptr = tensor_Aux.device_data(); - fusion_args.dAux = stride_Aux; - if constexpr (IsScaleFactorEnabled) { - fusion_args.scale_aux = scale_Aux.at(coord_0); - fusion_args.scale_aux_ptr = scale_Aux.device_data(); - } - if constexpr (IsAbsMaxEnabledAux) { - fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); - } - } - - - if constexpr (IsBlockScaleSupported) { - arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); - arguments.thread.norm_constant_ptr = norm_constant.device_data(); - } - } - - return arguments; - } - - auto to_host_args(ProblemShapeType problem_size) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - auto N = cute::get<1>(problem_shape_MNKL); - auto K = cute::get<2>(problem_shape_MNKL); - auto L = cute::get<3>(problem_shape_MNKL); - auto coord_0 = cutlass::make_Coord(0); - auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), - cute::make_layout(cute::make_shape(IsRowBiasEnabled ? M : N))); - auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), - cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); - auto Valpha = [&](){ - if constexpr (IsPerRowScaleEnabled) { - int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; - int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); - return cute::make_tensor(detail::make_iterator(alpha.host_data()), - cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); - } - else if constexpr (IsPerColScaleEnabled) { - int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; - int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); - return cute::make_tensor(detail::make_iterator(alpha.host_data()), - cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); - } - else { - return cute::make_tensor(detail::make_iterator(alpha.host_data()), - cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); - } - }(); - - auto Vbeta = [&]() { - if constexpr (IsPerRowScaleEnabled) { - int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; - int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); - return cute::make_tensor(detail::make_iterator(beta.host_data()), - cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); - } - else if constexpr (IsPerColScaleEnabled) { - int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; - int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); - return cute::make_tensor(detail::make_iterator(beta.host_data()), - cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); - } - else { - return cute::make_tensor(detail::make_iterator(beta.host_data()), - cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); - } - }(); - - auto SfD = [&](){ - if constexpr (IsBlockScaleSupported) { - auto tensor = make_tensor(detail::make_iterator(reference_SFD.host_data()), - Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); - return tensor; - } - else { - // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. - return D; - } - }(); - cutlass::reference::host::GettEpilogueParams< - ElementScalar, - ElementScalar, - ElementAccumulator, - ElementCompute, - decltype(C), - decltype(D), - decltype(Bias), - decltype(Aux), - decltype(Valpha), - decltype(Vbeta), - ActivationFunctor, - decltype(SfD), - Int, - cutlass::plus, - IsColBiasEnabled - , SfGenStrategy - > epilogue_params{}; - - epilogue_params.C = C; - epilogue_params.D = D; - epilogue_params.alpha = alpha.at(coord_0); - epilogue_params.beta = beta.at(coord_0); - - if constexpr (IsScaleFactorEnabled) { - epilogue_params.scale_a = scale_A.at(coord_0); - epilogue_params.scale_b = scale_B.at(coord_0); - epilogue_params.scale_c = scale_C.at(coord_0); - epilogue_params.scale_d = scale_D.at(coord_0); - } - - if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) - { - epilogue_params.Bias = Bias; - } - - if constexpr (IsAbsMaxEnabledD) { - epilogue_params.abs_max_D = reference_abs_max_D.host_data(); - } - - if constexpr (IsAuxInEnabled) { - epilogue_params.Aux = Aux; - } - - if constexpr (IsAuxOutEnabled) { - epilogue_params.Aux = Aux; - if constexpr (IsScaleFactorEnabled) { - epilogue_params.scale_aux = scale_Aux.at(coord_0); - } - if constexpr (IsAbsMaxEnabledAux) { - epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); - } - } - - if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { - epilogue_params.Valpha = Valpha; - if (vector_scale_mode == VectorScale::ENABLED) { - epilogue_params.Vbeta = Vbeta; - } - } - else { - if (use_device_scalars == ScalarLoc::ON_DEVICE) { - epilogue_params.Valpha = Valpha; - epilogue_params.Vbeta = Vbeta; - } - } - - if constexpr (IsBlockScaleSupported) { - epilogue_params.SfD = SfD; - epilogue_params.st = norm_constant.at(coord_0); - } - return epilogue_params; - } -}; - -template < - typename Gemm, - template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - bool force_legacy_epilogue = false, - typename ElementA = typename Gemm::GemmKernel::ElementA, - typename ElementB = typename Gemm::GemmKernel::ElementB - , typename RuntimeDatatypeA = void* - , typename RuntimeDatatypeB = void* -> -struct TestbedImpl { - // Kernel data types - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type - using HostCollectiveMainloopType = HostCollectiveMainloop; - - using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, - HostCollectiveDefaultEpilogue, - HostCollectiveEpilogue>; - - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementCompute = typename ElementComputeType::Type; - using ElementScalar = typename ElementScalarType::Type; - - using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; - using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; - using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; - using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; - - - using InternalElementA = typename Gemm::GemmKernel::ElementA; - using InternalElementB = typename Gemm::GemmKernel::ElementB; - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), - "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); - - static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - - - uint32_t sm_count; - // Used to force multi-wave tests for persistent kernel schedules - constexpr static int MaxSmCount = 16; - static constexpr uint64_t kDefaultSeed = 4096; - static constexpr uint32_t mma_promotion_interval = 4; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - HostCollectiveMainloopType collective_mma_inputs; - CollectiveEpilogue collective_epilogue; - - // - // Methods - // - - TestbedImpl( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), - collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } - - TestbedImpl( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), - collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } - - /// Initializes data structures - bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::initialize(problem_size, alpha, beta)"); -#endif - collective_mma_inputs.initialize(problem_size); - collective_epilogue.initialize(problem_size, alpha_, beta_); - - return true; - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cute::Shape problem_shape_MNKL, - ElementScalar alpha, - ElementScalar beta) - { - auto [M, N, K, L] = problem_shape_MNKL; - - bool passed = collective_mma_inputs.compare_reference(problem_shape_MNKL); - passed &= collective_epilogue.compare_reference(problem_shape_MNKL, alpha, beta); - EXPECT_TRUE(passed); - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_" - << M << "x" << N << "x" << K << "x" << L << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; - - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - collective_mma_inputs.print_tensors(file); - collective_epilogue.print_tensors(file); - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - ProblemShapeType problem_size, - ElementScalar alpha, - ElementScalar beta) - { - using namespace cute; - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); - auto epilogue_params = collective_epilogue.to_host_args(problem_size); - - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - - bool passed = compare_reference(problem_shape_MNKL, alpha, beta); - return passed; - } - - /// Determine if the CUDA device is sufficient to run the kernel - bool sufficient() { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); - - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - this->sm_count = properties.multiProcessorCount; - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - printf("failed due to smem_size\n"); - printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); - return false; - } - - return true; - } - - bool profile( - ProblemShapeType problem_size, - int iterations, - Gemm& gemm_op, - typename Gemm::Arguments& arguments, - cutlass::device_memory::allocation& workspace) { - int M = cute::size<0>(problem_size); - int N = cute::size<1>(problem_size); - int K = cute::size<2>(problem_size); - int L = 1; - if constexpr(cute::rank(ProblemShapeType{}) == 4) { - L = cute::size<3>(problem_size); - } - - - cutlass::Status status; - // - // Run the GEMM - // - cudaError_t result; - - for (int iter = 0; iter < iterations; ++iter) { - status = gemm_op(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - return false; - } - } - - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - return true; - } - - /// Executes one test - bool run( - ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - bool profiling = false, - detail::Iterations iterations = detail::Iterations{}, - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, - detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, - detail::Splits splits = detail::Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic - , RuntimeDatatypeA runtime_input_datatype_a = {} - , RuntimeDatatypeB runtime_input_datatype_b = {} - ) - { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run"); -#endif - - // Fail test if insufficient CUDA device - if (!sufficient()) { - CUTLASS_TRACE_HOST("TestbedImpl::run: Test failed due to insufficient CUDA device"); - std::cout << "Test failed due to insufficient CUDA device." << std::endl; - return false; - } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - else { - CUTLASS_TRACE_HOST("TestbedImpl::run: sufficient() returned true"); - } -#endif - - try { - const bool initialized = this->initialize(problem_size, alpha, beta); - if (not initialized) { - CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize returned false"); - std::cerr << "Initialization failed \n"; - return false; - } - } - catch ([[maybe_unused]] std::exception const& e) { - CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); - throw; - } - catch (...) { - CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an unknown exception"); - throw; - } - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize() returned true"); -#endif - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - if (not profiling) { - this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); - hw_info.sm_count = this->sm_count; - } - else { - this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = this->sm_count; - } - - typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; - if constexpr (cute::is_same_v) { - scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; - } - else { - scheduler_args = { static_cast(max_swizzle), raster_order }; - } - typename HostCollectiveMainloopType::Arguments mainloop_args; - - mainloop_args = collective_mma_inputs.to_args(); - - - if constexpr (IsRuntimeDataType) { - mainloop_args.runtime_data_type_a = runtime_input_datatype_a; - mainloop_args.runtime_data_type_b = runtime_input_datatype_b; - } - - - arguments = - { - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - mainloop_args, - collective_epilogue.to_args(problem_size), - hw_info, - scheduler_args - }; - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Creating gemm_op"); -#endif - Gemm gemm_op; - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling Gemm::get_workspace_size"); -#endif - size_t workspace_size = Gemm::get_workspace_size(arguments); -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size); -#endif - cutlass::device_memory::allocation workspace(workspace_size); - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.can_implement"); -#endif - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - const auto error_str = cudaGetErrorString(error); - CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); - std::cerr << "This test is not supported: " << error_str << "\n"; - return true; - } - - // - // Run the GEMM - // - - if (profiling) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling profile"); -#endif - return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); - } - else { - cudaError_t result; -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.initialize"); -#endif - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - const auto error_str = cudaGetErrorString(error); - CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); - } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.run"); -#endif - status = gemm_op.run(); - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - const auto error_str = cudaGetErrorString(error); - CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); - } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling cudaDeviceSynchronize"); -#endif - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST("TestbedImpl::run: cudaDeviceSynchronize reports non-success"); - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Calling this->verify"); -#endif - bool passed = this->verify(problem_size, alpha, beta); - if (!passed) { - CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify FAILED"); - cudaError_t error = cudaGetLastError(); - const auto error_str = cudaGetErrorString(error); - CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); - - std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta - << "\n"; - } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - else { - CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify passed"); - } -#endif - -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run: Reached end"); -#endif - return passed; - } - } -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - template class ActivationFunctor = cutlass::epilogue::thread::Identity, - bool force_legacy_epilogue = false, - typename ElementA = typename Gemm::GemmKernel::ElementA, - typename ElementB = typename Gemm::GemmKernel::ElementB - , typename RuntimeDatatypeA = void* - , typename RuntimeDatatypeB = void* -> -struct Testbed3x { - - using TestBedImpl = typename detail::TestbedImpl< - Gemm, - ActivationFunctor, - force_legacy_epilogue, - ElementA, - ElementB - , RuntimeDatatypeA - , RuntimeDatatypeB - >; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - - using ElementAccumulator = typename TestBedImpl::ElementAccumulator; - using ElementCompute = typename TestBedImpl::ElementCompute; - using ElementScalar = typename TestBedImpl::ElementScalar; - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - // Detail Implementation - TestBedImpl impl_; - - // - // Methods - // - Testbed3x( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed) - : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} - - /// Executes one test - bool run( - typename TestBedImpl::ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, - detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, - detail::Splits splits = detail::Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic, - bool profiling = false, - detail::Iterations iterations = detail::Iterations{} - , RuntimeDatatypeA runtime_input_datatype_a = {} - , RuntimeDatatypeB runtime_input_datatype_b = {} - ) - { - return impl_.run( - problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode - , runtime_input_datatype_a, runtime_input_datatype_b - ); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestGemmPerf3x(int iterations = 20) { - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalar = ElementAccumulator; - bool passed = true; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - - std::vector problem_size_m = { 4608 }; - std::vector problem_size_n = { 4608 }; - std::vector problem_size_k = { 8192 }; - - Testbed3x testbed; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(0), - RasterOrderOptions{}, detail::MaxSwizzleSize(1), detail::Splits{1}, DecompositionMode{}, - true, // profiling - detail::Iterations{iterations}); - - if (!passed) { - return false; - } - } - } - } - - return true; -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -template < - typename Gemm, - typename RuntimeDataTypeA, - typename RuntimeDataTypeB, - bool force_legacy_epilogue = false> -bool TestRuntimeDataTypeSmall( - RuntimeDataTypeA runtime_input_datatype_a, - RuntimeDataTypeB runtime_input_datatype_b, - double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; - using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; - using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; - - using InternalElementA = typename Gemm::GemmKernel::ElementA; - using InternalElementB = typename Gemm::GemmKernel::ElementB; - - CtaShape_MNK cta_shape; - static constexpr int SmCount = 16; - static constexpr int MultiplierOffsetM = 1; - static constexpr int MultiplierOffsetN = 2; - static constexpr int MultiplierOffsetK = 3; - int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - - float waves[] = {0.5, 1.25, 2.5}; - int cluster_m = 1; - int cluster_n = 1; - - std::vector problem_size_k; - if (override_problem_size_k.empty()) { - problem_size_k = {256 + max_alignment * MultiplierOffsetK, 512 + max_alignment * MultiplierOffsetK}; - } - else { - problem_size_k = override_problem_size_k; - } - - if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { - typename DispatchPolicy::ClusterShape cluster_shape; - cluster_m = cute::size<0>(cluster_shape); - cluster_n = cute::size<1>(cluster_shape); - } - - [[maybe_unused]] constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - - std::vector decomposition_modes = {DecompositionMode::Heuristic}; - static constexpr bool UsesStreamKScheduler = cute::is_same_v; - if constexpr (UsesStreamKScheduler) { - decomposition_modes.push_back(DecompositionMode::DataParallel); - decomposition_modes.push_back(DecompositionMode::SplitK); - decomposition_modes.push_back(DecompositionMode::StreamK); - } - bool passed = true; - - for (float wave : waves) { - for (int k : problem_size_k) { - int grid_m, grid_n = 0; - int num_grid = int(wave * SmCount); - - if (cluster_m >= cluster_n) { - grid_m = cluster_m; - grid_n = num_grid / grid_m; - // Align grid_n to cluster_n - grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); - } - else { - grid_n = cluster_n; - grid_m = num_grid / grid_n; - // Align grid_m to cluster_m - grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); - } - - int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment; - int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment; - - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - for (DecompositionMode decomp_mode : decomposition_modes) { - std::vector problem_splits = {detail::Splits{1}}; - if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { - problem_splits.push_back(detail::Splits{2}); - } - for (auto splits : problem_splits) { - - if constexpr (cute::is_same_v && - cute::is_same_v) { - // e2m1_e2m1 - if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && - runtime_input_datatype_b == cute::UMMA::MXF4Format::E2M1) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - else { - std::cout << "Unsupported configuration for runtime datatype MXFP4." << std::endl; - return false; - } - } - - else - if constexpr (cute::is_same_v && - cute::is_same_v) { - static_assert((cute::is_same_v || - cute::is_same_v || - cute::is_same_v) && - (cute::is_same_v || - cute::is_same_v || - cute::is_same_v), - "Runtime datatype must be selected with an appropriate static umbrella data type."); - if constexpr (cute::is_same_v && - cute::is_same_v) { - // e4m3_e2m1 - if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // Unsupport - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - // f6xf4 - else if constexpr (cute::is_same_v && - cute::is_same_v) { - // e3m2_e2m1 - if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // Unsupport - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - else if constexpr (cute::is_same_v && - cute::is_same_v) { - // e2m1_e2m1 - if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E2M1 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // Unsupport - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - else if constexpr (cute::is_same_v && - cute::is_same_v) { - // e4m3_e3m2 - if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E3M2) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // Unsupport - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - else if constexpr (cute::is_same_v && - cute::is_same_v) { - // e3m2_e2m3 - if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M3) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // Unsupported - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - else - if constexpr (cute::is_same_v && - cute::is_same_v) { - // e5m2_e5m2 - if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2) { - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // e4m3_e5m2 - else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2){ - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // e5m2_e4m3 - else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // e4m3_e4m3 - else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && - runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ - Testbed3x testbed(check_relative_equality, - use_device_scalars, - vector_scale_mode); - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - RasterOrderOptions::Heuristic, // raster_order - detail::MaxSwizzleSize(1), - splits, - decomp_mode, - false, - detail::Iterations{}, - runtime_input_datatype_a, - runtime_input_datatype_b - ); - } - // Unsupported - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - // Unsupported - else { - std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; - return false; - } - } - - else { - static_assert(cutlass::detail::dependent_false, - "Unsupported configuration for runtime datatype."); - } - - if (!passed) { - std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; - return false; - } - } // splits - } // decomposition_mode - } // k - } // waves - - return passed; -} - -template -bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, - ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode = VectorScale::ENABLED, - std::vector override_problem_size_k = {}) { - - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; - using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; - using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; - CtaShape_MNK cta_shape; - Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); - static constexpr int SmCount = 16; - static constexpr int MultiplierOffsetM = 1; - static constexpr int MultiplierOffsetN = 2; - static constexpr int MultiplierOffsetK = 3; - int max_alignment_k = 0; - int max_alignment_m = 0; - int max_alignment_n = 0; - - if constexpr (apply_alignment_offset) { - max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - max_alignment_n = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - max_alignment_m = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - } - // Alignment for SFD - if constexpr (detail::IsSfdEpi::value) { - using GmemLayoutTagScalefactor = typename Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagScalefactor; - constexpr int SFDVecSize = Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::SFVecSize; - if constexpr (cute::is_same_v) { - max_alignment_n = std::lcm(max_alignment_n, SFDVecSize); - } - else { - max_alignment_m = std::lcm(max_alignment_m, SFDVecSize); - } - } - - float waves[] = {0.5, 1.25, 2.5}; - int cluster_m = 1; - int cluster_n = 1; - - std::vector problem_size_k; - if (override_problem_size_k.empty()) { - problem_size_k = {256 + max_alignment_k * MultiplierOffsetK, 512 + max_alignment_k * MultiplierOffsetK}; - } - else { - problem_size_k = override_problem_size_k; - } - - if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { - typename DispatchPolicy::ClusterShape cluster_shape; - cluster_m = cute::size<0>(cluster_shape); - cluster_n = cute::size<1>(cluster_shape); - } - - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - - std::vector decomposition_modes = {DecompositionMode::Heuristic}; - static constexpr bool UsesStreamKScheduler = cute::is_same_v; - if constexpr (UsesStreamKScheduler) { - decomposition_modes.push_back(DecompositionMode::DataParallel); - decomposition_modes.push_back(DecompositionMode::SplitK); - decomposition_modes.push_back(DecompositionMode::StreamK); - } - bool passed = true; - - std::vector raster_order_options = {RasterOrderOptions::Heuristic}; - for (float wave : waves) { - for (int k : problem_size_k) { - int grid_m, grid_n = 0; - int num_grid = int(wave * SmCount); - - if (cluster_m >= cluster_n) { - grid_m = cluster_m; - grid_n = num_grid / grid_m; - // Align grid_n to cluster_n - grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); - } - else { - grid_n = cluster_n; - grid_m = num_grid / grid_n; - // Align grid_m to cluster_m - grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); - } - - int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment_m; - int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment_n; - int l = test_batched_alpha_beta && wave == waves[0] && k == problem_size_k[0] ? 2 : 1; // only test the smallest problem size - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, l}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - for (DecompositionMode decomp_mode : decomposition_modes) { - for (RasterOrderOptions raster_order : raster_order_options) { - std::vector problem_splits = {detail::Splits{1}}; - if constexpr (UsesStreamKScheduler) { - if (decomp_mode == DecompositionMode::SplitK) { - problem_splits.push_back(detail::Splits{2}); - problem_splits.push_back(detail::Splits{4}); - } - } - for (auto splits : problem_splits) { - try { - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - raster_order, // raster_order - detail::MaxSwizzleSize(0), - splits, - decomp_mode - ); - } - catch (std::exception const& e) { - EXPECT_TRUE(false) << "TestSmall: testbed.run {" - << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l - << ", alpha: " << alpha << ", beta: " << beta - << ", raster_order: " << detail::raster_order_to_string(raster_order) - << ", max_swizzle_size: 1" - << ", splits: " << static_cast(splits) - << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) - << "} threw an exception: " << e.what(); - throw; - } - catch (...) { - EXPECT_TRUE(false) << "TestSmall: testbed.run {" - << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l - << ", alpha: " << alpha << ", beta: " << beta - << ", raster_order: " << detail::raster_order_to_string(raster_order) - << ", max_swizzle_size: 1" - << ", splits: " << static_cast(splits) - << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) - << "} threw an exception (unknown)"; - throw; - } - EXPECT_TRUE(passed) << "TestSmall: testbed.run {" - << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l - << ", alpha: " << alpha << ", beta: " << beta - << ", raster_order: " << detail::raster_order_to_string(raster_order) - << ", max_swizzle_size: 1" - << ", splits: " << static_cast(splits) - << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) - << "} failed"; - - if (!passed) { - std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << l << " FAILED.\n"; - return false; - } - } // splits - } // raster_order - } // decomposition_mode - } // k - } // waves - - return passed; -} - -template -bool TestSmallFusion(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, - ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode = VectorScale::ENABLED, - std::vector override_problem_size_k = {}) { - return TestSmall(alpha, - beta, - check_relative_equality, - use_device_scalars, - vector_scale_mode, - override_problem_size_k); -} - - - -template < - typename Gemm, - template class ActivationFunctor = cutlass::epilogue::thread::Identity -> -bool TestAll(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { - using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - Testbed3x testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); - - int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD}); - int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD}); - if constexpr (std::is_base_of_v) { - max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux); - max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux); - } - std::vector problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m}; - std::vector problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n}; - - if constexpr (cute::is_same_v) { - problem_size_m.push_back(768); - problem_size_n.push_back(768); - } - - constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; - constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); - - int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - std::vector problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k}; - - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - std::vector decomposition_modes = {DecompositionMode::Heuristic}; - std::vector problem_splits = {detail::Splits{1}}; - static constexpr bool UsesStreamKScheduler = cute::is_same_v; - if constexpr (UsesStreamKScheduler) { - problem_splits.push_back(detail::Splits{2}); - problem_splits.push_back(detail::Splits{3}); - - decomposition_modes.push_back(DecompositionMode::DataParallel); - decomposition_modes.push_back(DecompositionMode::SplitK); - decomposition_modes.push_back(DecompositionMode::StreamK); - - // Use larger K sizes for stream-K tests - static constexpr int min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::min_iters_per_sk_unit_; - problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment_k}; - } - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; - std::vector max_swizzle_sizes{detail::MaxSwizzleSize{1}, detail::MaxSwizzleSize{4}}; - - bool passed = true; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (auto raster_order : raster_orders) { - for (auto max_swizzle_size : max_swizzle_sizes) { - for (DecompositionMode decomp_mode : decomposition_modes) { - - std::vector problem_splits = {detail::Splits{1}}; - if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { - auto max_splits = (k + TileShapeK - 1) / TileShapeK; - if (max_splits > 2) { - problem_splits.push_back(detail::Splits{2}); - } - if (max_splits > 3) { - problem_splits.push_back(detail::Splits{3}); - } - - problem_splits.push_back(detail::Splits{max_splits}); - - // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this - // case, split-K will fall back to a splitting factor of `max_splits`. - problem_splits.push_back(detail::Splits{max_splits + 1}); - } - for (auto splits : problem_splits) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - try { - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - raster_order, - max_swizzle_size, - splits, - decomp_mode - ); - } - catch (std::exception const& e) { - EXPECT_TRUE(false) << "TestAll: testbed.run {" - << "m: " << m << ", n: " << n << ", k: " << k - << ", alpha: " << alpha << ", beta: " << beta - << ", raster_order: ???" - << ", max_swizzle_size: " << static_cast(max_swizzle_size) - << ", splits: " << static_cast(splits) - << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) - << "} threw an exception: " << e.what(); - throw; - } - catch (...) { - EXPECT_TRUE(false) << "TestAll: testbed.run {" - << "m: " << m << ", n: " << n << ", k: " << k - << ", alpha: " << alpha << ", beta: " << beta - << ", raster_order: ???" - << ", max_swizzle_size: " << static_cast(max_swizzle_size) - << ", splits: " << static_cast(splits) - << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) - << "} threw an exception (unknown)"; - throw; - } - - EXPECT_TRUE(passed) << "TestAll: testbed.run {" - << "m: " << m << ", n: " << n << ", k: " << k - << ", alpha: " << alpha << ", beta: " << beta - << ", raster_order: ???" - << ", max_swizzle_size: " << static_cast(max_swizzle_size) - << ", splits: " << static_cast(splits) - << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) - << "} failed"; - - if (!passed) { - std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; - return false; - } - } // splits - } // decomposition_mode - } // max_swizzle_size - } // raster_order - } // k - } // n - } // m - - // if we do support batched GEMM, just run one test on it to save on test time - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - auto problem_size = ProblemShapeType{256 + max_alignment_m, 256 + max_alignment_n, 160 + max_alignment_k, /* l */ 3}; - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - - return passed; -} - -template -bool TestAllBiasElementwise(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { - return TestAll(alpha, beta, check_relative_equality); -} - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp deleted file mode 100644 index f18a7b39cbfe7dfb8d3251b2750e49261522de8a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp +++ /dev/null @@ -1,1742 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Testbed and host reference for EVT unittest -*/ - - -#pragma once -#include "gemm_testbed_3x.hpp" - -namespace test { -namespace gemm { -namespace device { - -/// Host-side tapply, tapply in cute is HOST_DEVICE -template -constexpr auto -tapply(T&& t, F&& f, G&& g, cute::seq) -{ - return g(f(std::get(static_cast(t)))...); -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT: Base class for EVT Node - -template < class ElementCompute_ > -class HostEVTNodeBase { -public: - using ElementCompute = ElementCompute_; - -private: - bool check_relative_equality_; - // Factors used for calculating relative equality. These default - // values are borrowed from those used by default in the CUTLASS - // profiler for performing relative equality checks. - float epsilon_ = 0.05f; - float nonzero_floor_ = 1.0f / 256.0f; - -public: - HostEVTNodeBase(){} - HostEVTNodeBase(bool check_relative_equality): - check_relative_equality_(check_relative_equality) { } - - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - if (check_relative_equality_) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, Element(epsilon_), Element(nonzero_floor_) - ); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - void* get_tensor_C_ptr() { - return nullptr; - } - - void* get_tensor_D_ptr() { - return nullptr; - } - - bool compare_reference(std::stringstream& error_ss) { - return true; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Accumulator - -template< class ElementCompute = float > -class HostAccumulator: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - - struct Arguments { }; - -public: - HostAccumulator(){} - template - HostAccumulator(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - :Base(check_relative_equality) {} - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - cutlass::NumericConverter accumulator_converter; - return accumulator_converter(acc); - } - - Arguments get_arguments() { - return Arguments{}; - } - - auto get_flatten_arguments() { - return cute::make_tuple(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Scalar Broadcast - -template < - int Value, - int BroadcastCount = 1, - class StrideMNL = cute::Stride, - template class ReductionFn = cutlass::multiplies, - class ElementCompute = float -> -class HostScalarBroadcast : public HostEVTNodeBase { -public: - - using Base = HostEVTNodeBase; - struct Arguments { - ElementCompute scalar[BroadcastCount] = {0}; - ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr }; - StrideMNL dScalar[BroadcastCount] = {}; - }; -private: - ElementCompute scalar_{}; - StrideMNL dScalar{}; - ElementCompute scalar_reduced_{}; -public: - HostScalarBroadcast(){} - - template - HostScalarBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - : Base(check_relative_equality), scalar_(ElementCompute(Value)) { - scalar_ = ElementCompute(Value); - scalar_reduced_ = scalar_; - for (int i = 1; i < BroadcastCount; ++i) { - scalar_reduced_ = ReductionFn{}(scalar_reduced_, ElementCompute(Value)); - } - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - - return scalar_reduced_; - } - - bool compare_reference(std::stringstream& error_ss) { - error_ss << "Scalar: " << float(scalar_) << "\n\n"; - return true; - } - - Arguments get_arguments() { - if constexpr (BroadcastCount == 1) - return Arguments{{scalar_}, {nullptr}, {dScalar}}; - else if constexpr (BroadcastCount == 2) - return Arguments{{scalar_, scalar_}, {nullptr, nullptr}, {dScalar, dScalar}}; - else if constexpr (BroadcastCount == 3) - return Arguments{{scalar_, scalar_, scalar_}, {nullptr, nullptr, nullptr}, {dScalar, dScalar, dScalar}}; - else - return Arguments{{scalar_}, {nullptr}, {dScalar}}; - } - - auto get_flatten_arguments() { - if constexpr (BroadcastCount == 1) { - return cute::make_tuple(scalar_, nullptr); - } - else if constexpr (BroadcastCount == 2) { - return cute::make_tuple(scalar_, scalar_, nullptr, nullptr); - } - else if constexpr (BroadcastCount == 3) { - return cute::make_tuple(scalar_, scalar_, scalar_, nullptr, nullptr, nullptr); - } - else { - return cute::make_tuple(scalar_, nullptr); - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Row Broadcast -template < - typename ElementBias_, - typename StrideMNL = cute::Stride, - typename ElementCompute = float -> -class HostRowBroadcast: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using ElementBias = ElementBias_; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - struct Arguments { - ElementBias const* ptr_row = nullptr; - ElementBias null_default = ElementBias(0); - StrideMNL dRow = {}; - }; -private: - cutlass::NumericConverter bias_converter_; - cutlass::HostTensor bias_; - int N_; -public: - HostRowBroadcast(){} - template - HostRowBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - : Base(check_relative_equality) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - N_ = cute::get<1>(problem_shape_MNKL); - bias_.resize(cutlass::Coord<1>(N_)); - - EXPECT_TRUE( - detail::initialize_tensor( - bias_.host_view(), cutlass::Distribution::Uniform, - seed - ) - ); - bias_.sync_device(); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - auto TensorBias = cute::make_tensor(bias_.host_data(), - cute::make_layout(cute::make_shape(cute::_1{}, N_))); - - return bias_converter_(TensorBias(1, n + n_b)); - } - - bool compare_reference(std::stringstream& error_ss) { - error_ss - << "PerColumnBias = \n" << bias_.host_view() << "\n\n"; - return true; - } - - Arguments get_arguments() { - return {bias_.device_data()}; - } - - auto get_flatten_arguments() { - return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); - } - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Column Broadcast -template < - typename ElementBias_, - typename StrideMNL = cute::Stride, - typename ElementCompute = float -> -class HostColBroadcast: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using ElementBias = ElementBias_; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - struct Arguments { - ElementBias const* ptr_row = nullptr; - ElementBias null_default = ElementBias(0); - StrideMNL dRow = {}; - }; -private: - cutlass::NumericConverter bias_converter_; - cutlass::HostTensor bias_; - int M_; -public: - HostColBroadcast(){} - template - HostColBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - : Base(check_relative_equality) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - M_ = cute::get<0>(problem_shape_MNKL); - bias_.resize(cutlass::Coord<1>(M_)); - - EXPECT_TRUE( - detail::initialize_tensor( - bias_.host_view(), cutlass::Distribution::Uniform, - seed - ) - ); - bias_.sync_device(); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - auto TensorBias = cute::make_tensor(bias_.host_data(), - cute::make_layout(cute::make_shape(M_, cute::_1{}))); - - return bias_converter_(TensorBias(m + m_b, 1)); - } - - bool compare_reference(std::stringstream& error_ss) { - error_ss - << "PerRowBias = \n" << bias_.host_view() << "\n\n"; - return true; - } - - Arguments get_arguments() { - return {bias_.device_data()}; - } - - auto get_flatten_arguments() { - return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Aux Load - -template < - typename ElementAuxLoad_, - typename LayoutTagAux_, - bool isC = false, - typename ElementCompute = float -> -class HostAuxLoad: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using ElementAuxLoad = ElementAuxLoad_; - using LayoutTagAux = LayoutTagAux_; - - using StrideAux = cutlass::gemm::TagToStrideC_t; - struct Arguments_Aux { - ElementAuxLoad const *ptr_aux = nullptr; - ElementAuxLoad null_default = ElementAuxLoad(0); - StrideAux dAux = {}; - }; - - struct Arguments_C {}; - - using Arguments = cute::conditional_t; - -private: - cutlass::NumericConverter aux_load_converter_; - cutlass::HostTensor tensor_aux_load_; - - int M_, N_, L_; - - StrideAux stride_aux_; -public: - HostAuxLoad(){} - template - HostAuxLoad(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - : Base(check_relative_equality) { - auto problem_shape_NMKL = cute::append<4>(problem_size, 1); - auto [M_, N_, K, L_] = problem_shape_NMKL; - auto aux_coord = cutlass::make_Coord(M_ * L_, N_); - tensor_aux_load_.resize( - aux_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory( - aux_coord, typename LayoutTagAux::Stride() - ) - ); - EXPECT_TRUE( - detail::initialize_tensor( - tensor_aux_load_.host_view(), - cutlass::Distribution::Uniform, - seed - ) - ); - tensor_aux_load_.sync_device(); - stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - - - auto TensorAuxLoad = cute::make_tensor(tensor_aux_load_.host_data(), - cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); - return aux_load_converter_(TensorAuxLoad(m + m_b, n + n_b, l)); - } - - bool compare_reference(std::stringstream& error_ss) { - if constexpr (!isC) { - error_ss - << "AuxLoad = \n" << tensor_aux_load_.host_view()<< "\n\n"; - } - return true; - } - - void* get_tensor_C_ptr() { - if constexpr (isC) { - return static_cast(tensor_aux_load_.device_data()); - } - else { - return nullptr; - } - } - - Arguments get_arguments() { - if constexpr (isC) - return {}; - else - return {tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_}; - } - - auto get_flatten_arguments() { - if constexpr (isC) - return cute::make_tuple(); - else - return cute::make_tuple(tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Compute - -template -T* findNonNullPtr(T* first_ptr) { - return first_ptr; -} - -template -T* findNonNullPtr(T* first_ptr, Args... args) { - if (first_ptr) { - return first_ptr; - } - return findNonNullPtr(args...); -} - -template < - template class ComputeOp_, - typename ElementCompute = float -> -class HostCompute: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using ComputeOp = ComputeOp_; - - struct Arguments { - struct OpArgs {} op; - }; -private: - ComputeOp op_; -public: - HostCompute(){} - template - HostCompute(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): - Base(check_relative_equality) { } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc, Args... frg_inputs) { - return op_(frg_inputs...); - } - - Arguments get_arguments(){ - return {}; - } - - auto get_flatten_arguments() { - return cute::make_tuple(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Aux Store - -template < - class ElementAuxStore_, - typename LayoutTagAux_, - bool isD = false, - bool isRelu = false, - typename ElementCompute = float -> -class HostAuxStore: public HostEVTNodeBase { -public: - using ElementAuxStore = ElementAuxStore_; - using LayoutTagAux = LayoutTagAux_; - - using Base = HostEVTNodeBase; - - using StrideAux = cutlass::gemm::TagToStrideC_t; - struct Arguments_Aux { - struct OpArgs { - ElementAuxStore* ptr_aux = nullptr; - StrideAux dAux = {}; - } op; - }; - - struct Arguments_D {}; - - using Arguments = cute::conditional_t; - - -private: - cutlass::NumericConverter destination_converter_; - cutlass::HostTensor tensor_aux_store_; - cutlass::HostTensor reference_aux_store_; - int M_, N_, L_; - StrideAux stride_aux_; -public: - HostAuxStore(){} - template - HostAuxStore(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): - Base(check_relative_equality) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M_, N_, K, L_] = problem_shape_MNKL; - auto aux_coord = cutlass::make_Coord(M_ * L_, N_); - tensor_aux_store_.resize( - aux_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory( - aux_coord, typename LayoutTagAux::Stride() - ) - ); - - reference_aux_store_.resize( - aux_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory( - aux_coord, typename LayoutTagAux::Stride() - ) - ); - tensor_aux_store_.sync_device(); - stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc, ElementCompute child_0_result) { - - auto TensorAuxStore = cute::make_tensor(detail::make_iterator(static_cast(reference_aux_store_.host_data())), - cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); - if constexpr (isRelu) - TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result >= 0); - else - TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result); - return child_0_result; - } - - bool compare_reference(std::stringstream& error_ss) { - // Verify the store node - tensor_aux_store_.sync_host(); - - bool equal = this->equality_check(reference_aux_store_.host_view(), tensor_aux_store_.host_view()); - if (!equal) { - error_ss - << "\n\nReference =\n" << reference_aux_store_.host_view() - << "\n\nComputed =\n" << tensor_aux_store_.host_view() << "\n\n"; - } - return equal; - } - - void* get_tensor_D_ptr() { - if constexpr (isD) - return static_cast(tensor_aux_store_.device_data()); - else - return nullptr; - } - - Arguments get_arguments() { - if constexpr (isD) { - return {}; - } - else { - return {tensor_aux_store_.device_data(), stride_aux_}; - } - } - - auto get_flatten_arguments() { - if constexpr (isD) { - return cute::make_tuple(); - } - else { - return cute::make_tuple(tensor_aux_store_.device_data(), stride_aux_); - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Row Reduce - -template < - template class ReduceFn, - typename ElementReduce, - bool FinalReduction = true, // Should match the FinalReduction in Device type - typename CtaTileShapeMNK = cute::Shape, - typename ElementCompute = float -> -class HostRowReduce: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - using ElementDst = cute::conditional_t; - - static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); - static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); - - struct Arguments { - struct OpArgs { - ElementReduce* ptr_row = nullptr; - ElementCompute reduce_identity = 0; - cute::Stride dRow = {}; - } op; - }; - -private: - cutlass::NumericConverter destination_converter_; - cutlass::HostTensor tensor_row_reduce_; - cutlass::HostTensor reduce_buffer_; - cutlass::HostTensor reference_row_reduce_; - int N_; - ReduceFn reduce_fn_; - - int extent_m_; - int extent_n_; - int extent_l_; -public: - HostRowReduce(){} - template - HostRowReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): - Base(check_relative_equality) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - N_ = cute::get<1>(problem_shape_MNKL); - if constexpr (FinalReduction) { - tensor_row_reduce_.resize(cutlass::Coord<1>(N_)); - reference_row_reduce_.resize(cutlass::Coord<1>(N_)); - reduce_buffer_.resize(cutlass::Coord<1>(N_)); - } - else { - auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); - extent_m_ = cute::get<0>(NumTile); - extent_n_ = cute::get<1>(NumTile) * TileN; - extent_l_ = cute::get<2>(NumTile); - auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); - tensor_row_reduce_.resize(shape); - reference_row_reduce_.resize(shape); - reduce_buffer_.resize(shape); - } - - cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc, ElementCompute child_0_result) { - if constexpr (FinalReduction) { - auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), - cute::make_layout(cute::make_shape(cute::_1{}, N_))); - TensorRowReduce(1, n + n_b) = reduce_fn_(TensorRowReduce(1, n + n_b), child_0_result); - } - else { - auto TensorRowReduce = cute::make_tensor( - reduce_buffer_.host_data(), - cute::make_layout( - cute::make_shape(extent_m_, extent_n_, extent_l_), - cute::make_stride(extent_n_, 1, extent_m_ * extent_l_) - ) - ); - TensorRowReduce((m+m_b)/TileM, n+n_b, l) = reduce_fn_(TensorRowReduce((m+m_b)/TileM, n+n_b, l), child_0_result); - } - - return child_0_result; - } - - bool compare_reference(std::stringstream& error_ss) { - // Verify the store node - tensor_row_reduce_.sync_host(); - - auto TensorRowReduce = cute::make_tensor(reference_row_reduce_.host_data(), - cute::make_layout(cute::make_shape(reference_row_reduce_.size()))); - - auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), - cute::make_layout(cute::make_shape(reduce_buffer_.size()))); - - // Filling the reference tensor with the reduce buffer - for (uint64_t n = 0; n < size(TensorRowReduce); n ++) { - TensorRowReduce(n) = destination_converter_(TensorReduceBuffer(n)); - } - - bool equal = this->equality_check(reference_row_reduce_.host_view(), tensor_row_reduce_.host_view()); - if (!equal) { - error_ss - << "\n\nRow Reduce Reference =\n" << reference_row_reduce_.host_view() - << "\n\nRow Reduce Computed =\n" << tensor_row_reduce_.host_view() << "\n\n"; - } - return equal; - } - - Arguments get_arguments() { - return {tensor_row_reduce_.device_data()}; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Column Reduce - -template < - template class ReduceFn, - typename ElementReduce, - bool FinalReduction = true, // Should match the FinalReduction in Device type - typename CtaTileShapeMNK = cute::Shape, - typename ElementCompute = float -> -class HostColumnReduce: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - using ElementDst = cute::conditional_t; - - static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); - static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); - - struct Arguments { - struct OpArgs { - ElementReduce* ptr_col = nullptr; - ElementCompute reduce_identity = 0; - cute::Stride dRow = {}; - } op; - }; - -private: - cutlass::NumericConverter destination_converter_; - cutlass::HostTensor tensor_column_reduce_; - cutlass::HostTensor reduce_buffer_; - cutlass::HostTensor reference_column_reduce_; - int M_; - ReduceFn reduce_fn_; - - int extent_m_; - int extent_n_; - int extent_l_; -public: - HostColumnReduce(){} - template - HostColumnReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): - Base(check_relative_equality) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - M_ = cute::get<0>(problem_shape_MNKL); - - if constexpr (FinalReduction) { - tensor_column_reduce_.resize(cutlass::Coord<1>(M_)); - reference_column_reduce_.resize(cutlass::Coord<1>(M_)); - reduce_buffer_.resize(cutlass::Coord<1>(M_)); - } - else { - auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); - extent_m_ = cute::get<0>(NumTile) * TileM; - extent_n_ = cute::get<1>(NumTile); - extent_l_ = cute::get<2>(NumTile); - auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); - tensor_column_reduce_.resize(shape); - reference_column_reduce_.resize(shape); - reduce_buffer_.resize(shape); - } - - cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc, ElementCompute child_0_result) { - auto TensorColReduce = cute::make_tensor(reduce_buffer_.host_data(), - cute::make_layout(cute::make_shape(M_, cute::_1{}))); - if constexpr (FinalReduction) { - TensorColReduce(m + m_b, 1) = reduce_fn_(TensorColReduce(m + m_b, 1), child_0_result); - } - else { - auto shape = reduce_buffer_.extent(); - auto TensorColReduce = cute::make_tensor( - reduce_buffer_.host_data(), - cute::make_layout( - cute::make_shape(extent_m_, extent_n_, extent_l_), - cute::make_stride(1, extent_m_, extent_m_ * extent_l_) - ) - ); - TensorColReduce(m+m_b, (n+n_b)/TileN, l) = reduce_fn_(TensorColReduce(m+m_b, (n+n_b)/TileN, l), child_0_result); - } - return child_0_result; - } - - bool compare_reference(std::stringstream& error_ss) { - // Verify the store node - tensor_column_reduce_.sync_host(); - - auto TensorColReduce = cute::make_tensor(reference_column_reduce_.host_data(), - cute::make_layout(cute::make_shape(reference_column_reduce_.size()))); - - auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), - cute::make_layout(cute::make_shape(reduce_buffer_.size()))); - - // Filling the reference tensor with the reduce buffer - for (uint64_t m = 0; m < size(TensorColReduce); m ++) { - TensorColReduce(m) = destination_converter_(TensorReduceBuffer(m)); - } - - bool equal = this->equality_check(reference_column_reduce_.host_view(), tensor_column_reduce_.host_view()); - if (!equal) { - error_ss - << "\n\nColumn Reduce Reference =\n" << reference_column_reduce_.host_view() - << "\n\nColumn Reduce Computed =\n" << tensor_column_reduce_.host_view() << "\n\n"; - } - return equal; - } - - Arguments get_arguments() { - return {tensor_column_reduce_.device_data()}; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// EVT - Scalar Reduce - -template < - template class ReduceFn, - typename ElementReduce, - typename ElementCompute = float, - bool enabled = true -> -class HostScalarReduce: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - struct Arguments { - struct OpArgs { - ElementReduce* ptr_scalar = nullptr; - ElementCompute reduce_identity = 0; - cute::Stride dScalar = {}; - } op; - }; - -private: - cutlass::NumericConverter destination_converter_; - cutlass::HostTensor tensor_scalar_reduce_; - cutlass::HostTensor reduce_buffer_; - cutlass::HostTensor reference_scalar_reduce_; - ReduceFn reduce_fn_; -public: - HostScalarReduce(){} - template - HostScalarReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): - Base(check_relative_equality) { - tensor_scalar_reduce_.resize(cutlass::Coord<1>(1)); - reference_scalar_reduce_.resize(cutlass::Coord<1>(1)); - reduce_buffer_.resize(cutlass::Coord<1>(1)); - - tensor_scalar_reduce_.sync_device(); - cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc, ElementCompute child_0_result) { - auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), - cute::make_layout(cute::make_shape(cute::_1{}))); - TensorRowReduce(0) = reduce_fn_(TensorRowReduce(0), child_0_result); - return child_0_result; - } - - bool compare_reference(std::stringstream& error_ss) { - if constexpr (enabled) { - // Verify the store node - tensor_scalar_reduce_.sync_host(); - - auto TensorRowReduce = cute::make_tensor(reference_scalar_reduce_.host_data(), - cute::make_layout(cute::make_shape(cute::_1{}))); - - auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), - cute::make_layout(cute::make_shape(cute::_1{}))); - - // Filling the reference tensor with the reduce buffer - TensorRowReduce(0) = destination_converter_(TensorReduceBuffer(0)); - - bool equal = this->equality_check(reference_scalar_reduce_.host_view(), tensor_scalar_reduce_.host_view()); - if (!equal) { - error_ss - << "\n\nScalar Reduce Reference =\n" << reference_scalar_reduce_.host_view() - << "\n\nScalar Reduce Computed =\n" << tensor_scalar_reduce_.host_view() << "\n\n"; - } - return equal; - } - else { - return true; - } - - } - - Arguments get_arguments() { - return {tensor_scalar_reduce_.device_data()}; - } - - auto get_flatten_arguments() { - return cute::make_tuple(tensor_scalar_reduce_.device_data()); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Host EVT wrapper - -/// The ArgumentPack is used to model the alignment when num ops <= 4 -template -struct ArgumentPack; - -template -struct ArgumentPack { - T arg; - ArgumentPack(T first): - arg(first) {} -}; - -template -struct ArgumentPack { - First arg; - ArgumentPack rest_args; - - ArgumentPack(First first, Rest... rest) : - arg(first), rest_args(rest...) {} -}; - - -/// Base class for Host Visitor -template -struct HostVisitorBase: public HostEVTNodeBase { -public: - using Base = HostEVTNodeBase; - - using Arguments_struct = ArgumentPack; - using Arguments_tuple = cute::tuple; - - constexpr static int Rm1 = sizeof...(Ops); - constexpr static bool cond = Rm1 > 4; - using Arguments = cute::conditional_t; - - std::tuple ops; - - HostVisitorBase(){} - template - HostVisitorBase(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - :Base(check_relative_equality), - ops(test::gemm::device::tapply(std::tuple{}, - [&] (auto&& op) { - using Op = cute::remove_cvref_t; - return Op(problem_size, check_relative_equality, seed); - }, - [] (auto&&... _ops) { - return std::make_tuple(_ops...); - }, - cute::make_seq{} - )){ } - - bool compare_reference(std::stringstream& error_ss) { - return cute::detail::tapply(ops, - [&](auto& op) { - return op.compare_reference(error_ss); - }, - [&] (auto&&... inputs) { - return arrayAnd(inputs...); - }, - cute::make_seq{} - ); - } - - void* get_tensor_C_ptr() { - return cute::detail::tapply(ops, - [&](auto& op) { - return op.get_tensor_C_ptr(); - }, - [&] (auto&&... inputs) { - return findNonNullPtr(inputs...); - }, - cute::make_seq{} - ); - } - - void* get_tensor_D_ptr() { - return cute::detail::tapply(ops, - [&](auto& op) { - return op.get_tensor_D_ptr(); - }, - [&] (auto&&... inputs) { - return findNonNullPtr(inputs...); - }, - cute::make_seq{} - ); - } - - Arguments get_arguments() { - return test::gemm::device::tapply(ops, - [&](auto& op) { - return op.get_arguments(); - }, - [&] (auto&&... args) { - if constexpr (Rm1 > 4) { - return cute::make_tuple(args...); - } - else { - return Arguments(args...); - } - }, - cute::make_seq{} - ); - } - - auto get_flatten_arguments() { - return test::gemm::device::tapply(ops, - [&](auto& op) { - return op.get_flatten_arguments(); - }, - [&] (auto&&... args) { - return flatten(cute::make_tuple(args...)); - }, - cute::make_seq{} - ); - } - - bool arrayAnd(bool passed) { - return passed; - } - - template - bool arrayAnd(bool first_passed, Args... passed) { - if (first_passed) { - return arrayAnd(passed...); - } - return first_passed; - } - -}; - - -/// Tree-struct visitor -template -struct HostTreeVisitor: public HostVisitorBase { -public: - using ElementCompute = typename NodeOp::Base::ElementCompute; - using Base = HostVisitorBase; - using Arguments = typename Base::Arguments; - - constexpr static int Rm1 = sizeof...(ChildOps); - - HostTreeVisitor(){} - template - HostTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - :Base(problem_size, check_relative_equality, seed){ } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - return cute::detail::tapply(this->ops, - [&] (auto& op) { - return op.visit(m, n, l, m_b, n_b, acc); - }, - [&] (auto&&... frg_inputs) { - return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); - }, - cute::make_seq{} - ); - } -}; - - -/// General Graph visitor -template -struct HostTopoVisitor: public HostVisitorBase { -public: - using Base = HostVisitorBase; - constexpr static int Rm1 = Base::Rm1; - using Arguments = typename Base::Arguments; - -private: - ElementCompute frg_outputs_[Rm1]; -public: - HostTopoVisitor(){} - template - HostTopoVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - :Base(problem_size, check_relative_equality, seed) { } - - template - ElementCompute visit_( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - frg_outputs_[I] = cute::transform_apply(cute::get(EdgeTuple{}), - [&] (auto&& _E) { - constexpr int e = cute::remove_cvref_t::value; - return frg_outputs_[e]; - }, - [&] (auto const&... frg_inputs) { - ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); - return res; - } - ); - - if constexpr (I < Rm1 - 1) { - return visit_(m, n, l, m_b, n_b, acc); - } - else { - return frg_outputs_[I]; - } - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - - return visit_(m, n, l, m_b, n_b, acc); - } - -}; - - -/// SplitTree visitor -template -struct HostSplitTreeVisitor: public HostVisitorBase { -public: - using Base = HostVisitorBase; - using Arguments = typename Base::Arguments; - - constexpr static int Rm2 = sizeof...(AuxOutTrees); - -private: - ElementCompute frg_input_; -public: - HostSplitTreeVisitor(){} - template - HostSplitTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) - :Base(problem_size, check_relative_equality, seed) { } - - template - void visitAux( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator frag) { - std::get(this->ops).visit(m, n, l, m_b, n_b, frag); - - if constexpr (I < Rm2 - 1) { - return visitAux(m, n, l, m_b, n_b, frag); - } - else { - return; - } - } - - template - ElementCompute visit( - int64_t m, int64_t n, int64_t l, int m_b, int n_b, - ElementAccumulator acc) { - - /// Compute the input tree - frg_input_ = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); - - /// Compute the aux out tree - visitAux(m, n, l, m_b, n_b, frg_input_); - /// Visit the output tree - return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input_); - } -}; - -/// Universal testbed for EVT w/o smem -template -class Testbed3xEVTnoSmem { -public: - // The EVT Module to test - using EVTModule = EVT; //typename EVT::EVTModule; - - using TestBedImpl = typename detail::TestbedImpl; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - using ElementAccumulator = typename Kernel::ElementAccumulator; - using ElementC = typename Kernel::ElementC; - using ElementD = typename Kernel::ElementD; - - using ProblemShapeType = typename Kernel::ProblemShape; - - using LayoutTagA = typename TestBedImpl::LayoutTagA; - using LayoutTagB = typename TestBedImpl::LayoutTagB; - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - // - // Methods - // - Testbed3xEVTnoSmem( - bool check_relative_equality_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed ) : - impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, - init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), - check_relative_equality(check_relative_equality_) { } - - Testbed3xEVTnoSmem( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed ) : - impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, - init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), - check_relative_equality(false) { } - - /// Initializes data structures - void initialize(ProblemShapeType problem_size) { - // - // Allocate the GEMM workspace for A/B tensor - // - impl_.initialize(problem_size); - } - // Detail Implementation - TestBedImpl impl_; - - // Whether to use relative equality checks - bool check_relative_equality; - - bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { - - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - auto N = cute::get<1>(problem_shape_MNKL); - auto K = cute::get<2>(problem_shape_MNKL); - auto L = cute::get<3>(problem_shape_MNKL); - - auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), - cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); - auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), - cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); - auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); - - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; - - /// Reference Kernel - static int constexpr kBlockM = 64; - static int constexpr kBlockN = 64; - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - /// Epilogue EVT - for (int n_b = 0; n_b < kBlockN; ++n_b) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { - host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); - } - } - } - } - } - } - - std::stringstream error_ss; - bool passed = host_reference.compare_reference(error_ss); - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_" - << M << "x" << N << "x" << K << "x" << L << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; - - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K - << ", Batch count = " << L << "\n\n"; - - file - << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() - << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view(); - - file << error_ss.str(); - } - - return passed; - } - - bool run( - ProblemShapeType problem_size, - RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, - detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, - detail::Splits splits = detail::Splits{}, - DecompositionMode decomposition_mode = DecompositionMode::Heuristic, - int iterations = 20, - bool profiling = false) { - // Fail test if insufficient CUDA device - if (!impl_.sufficient()) { - std::cout << "Test failed due to insufficient CUDA device." << std::endl; - return false; - } - // - // Initialize the Gemm operator - // - - typename Gemm::Arguments arguments; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - if (not profiling) { - impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); - hw_info.sm_count = impl_.sm_count; - } - else { - impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = impl_.sm_count; - } - - typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; - if constexpr (cute::is_same_v) { - scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; - } - else { - scheduler_args = { static_cast(max_swizzle), raster_order }; - } - - /// Initializes data structures - /// A/B/C/D Tensor - initialize(problem_size); - - /// Initialize the epilogue arguments - EVTModule host_reference(problem_size, check_relative_equality, 2024); - - arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - { - impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, - impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b - }, - {}, - hw_info, - scheduler_args - }; - - // Filling in the thread arguments - if constexpr (FlatArgs) { - auto epilogue_args = host_reference.get_flatten_arguments(); - std::memcpy(&arguments.epilogue.thread, &epilogue_args, sizeof(epilogue_args)); - - arguments.epilogue.ptr_C = static_cast(host_reference.get_tensor_C_ptr()); - arguments.epilogue.dC = impl_.collective_epilogue.stride_c; - - arguments.epilogue.ptr_D = static_cast(host_reference.get_tensor_D_ptr()); - arguments.epilogue.dD = impl_.collective_epilogue.stride_d; - } - else { - auto epilogue_args = host_reference.get_arguments(); - std::memcpy(&arguments.epilogue, &epilogue_args, sizeof(epilogue_args)); - } - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // - // Run the GEMM - // - if (profiling) { - return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); - } - else { - cudaError_t result; - status = gemm_op.initialize(arguments, workspace.get()); - status = gemm_op.run(); - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - bool passed = this->verify(problem_size, host_reference); - if (!passed) { - std::cout << "Error : Failed \n"; - } - - return passed; - } -}; - -/// Universal testbed for EVT -template -class Testbed3xEVT { -public: - // The EVT Module to test - using EVTModule = typename EVT::EVTModule; - - using TestBedImpl = typename detail::TestbedImpl; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - using ElementAccumulator = typename Kernel::ElementAccumulator; - using ElementC = typename Kernel::ElementC; - using ElementD = typename Kernel::ElementD; - - using ProblemShapeType = typename Kernel::ProblemShape; - - using LayoutTagA = typename TestBedImpl::LayoutTagA; - using LayoutTagB = typename TestBedImpl::LayoutTagB; - using LayoutTagC = typename TestBedImpl::LayoutTagC; - using LayoutTagD = typename TestBedImpl::LayoutTagD; - - // - // Methods - // - Testbed3xEVT( - bool check_relative_equality_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, - init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), - check_relative_equality(check_relative_equality_) { } - - Testbed3xEVT( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, - init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), - check_relative_equality(false) { } - - Testbed3xEVT( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, - CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, - init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), - check_relative_equality(false) { } - - /// Initializes data structures - void initialize(ProblemShapeType problem_size) { - // - // Allocate the GEMM workspace for A/B tensor - // - impl_.initialize(problem_size); - } - // Detail Implementation - TestBedImpl impl_; - - // Whether to use relative equality checks - bool check_relative_equality; - - bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { - - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - auto N = cute::get<1>(problem_shape_MNKL); - auto K = cute::get<2>(problem_shape_MNKL); - auto L = cute::get<3>(problem_shape_MNKL); - - auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), - cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); - auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), - cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); - auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); - - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; - - /// Reference Kernel - static int constexpr kBlockM = 64; - static int constexpr kBlockN = 64; - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - /// Epilogue EVT - for (int n_b = 0; n_b < kBlockN; ++n_b) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { - host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); - } - } - } - } - } - } - - std::stringstream error_ss; - bool passed = host_reference.compare_reference(error_ss); - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_" - << M << "x" << N << "x" << K << "x" << L << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; - - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K - << ", Batch count = " << L << "\n\n"; - - file - << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() - << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() - << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; - - file << error_ss.str(); - } - - return passed; - } - - bool run( - ProblemShapeType problem_size, - bool profiling = false, - int iterations = 20, - int splits = 1) { - // Fail test if insufficient CUDA device - if (!impl_.sufficient()) { - std::cout << "Test failed due to insufficient CUDA device." << std::endl; - return false; - } - // - // Initialize the Gemm operator - // - - typename Gemm::Arguments arguments; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - if (not profiling) { - impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); - hw_info.sm_count = impl_.sm_count; - } - else { - impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = impl_.sm_count; - } - - typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; - if constexpr (cute::is_same_v) { - scheduler_args = { splits }; - } - - /// Initializes data structures - /// A/B/C/D Tensor - initialize(problem_size); - - /// Initialize the epilogue arguments - EVTModule host_reference(problem_size, check_relative_equality, 2024); - - arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - { - impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, - impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b - }, - { // Epilogue arguments - {}, // thread - static_cast(host_reference.get_tensor_C_ptr()), - impl_.collective_epilogue.stride_c, - static_cast(host_reference.get_tensor_D_ptr()), - impl_.collective_epilogue.stride_d - }, // Epilogue arguments end - hw_info, - scheduler_args - }; - - // Filling in the thread arguments - typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); - std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // - // Run the GEMM - // - if (profiling) { - return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); - } - else { - cudaError_t result; - status = gemm_op.initialize(arguments, workspace.get()); - status = gemm_op.run(); - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - bool passed = this->verify(problem_size, host_reference); - if (!passed) { - std::cout << "Error : Failed \n"; - } - - return passed; - } -}; - -template -bool TestAllEVT(bool check_relative_equality = false) { - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; - std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; - - if constexpr (cute::is_same_v) { - problem_size_m.push_back(768); - problem_size_n.push_back(768); - } - - constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; - constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); - - std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; - - Testbed3xEVT testbed(check_relative_equality); - bool passed = true; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - passed = testbed.run(problem_size); - - if (!passed) { - return false; - } - } - } - } - - // if we do support batched GEMM, just run one test on it to save on test time - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; - passed = testbed.run( - problem_size - ); - - if (!passed) { - return false; - } - } - - return passed; -} - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp deleted file mode 100644 index cbc54ec582d88d9039968d8153cf6127a06ec274..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ /dev/null @@ -1,2409 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Testbed for Ptr-Array and Grouped GEMM interface -*/ - -#pragma once - -#include -#include -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gett.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/fusion/operations.hpp" -#include "cutlass/complex.h" -#include "testbed_utils.h" - -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/gemm/gemm.h" - -#include "cute/int_tuple.hpp" -#include "cute/layout.hpp" -#include "cute/numeric/int.hpp" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -enum class ScalarLoc { - ON_HOST = 0, - ON_DEVICE = 1 -}; - -enum class VectorScale { - DISABLED = 0, - ENABLED = 1 -}; - -enum class CheckEquality { - EXACT = 0, - RELATIVE = 1 -}; - -namespace detail{ - -// Helper classes that take default data type when -// the Gemm::EpilogueOutputOp does not have ElementCompute -// and ElementScalar. -// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) -template -struct ElementComputeType { - using Type = Default; -}; - -template -struct ElementComputeType> { - using Type = typename Gemm::EpilogueOutputOp::ElementCompute; -}; - -template -struct ElementScalarType { - using Type = Default; -}; - -template -struct ElementScalarType> { - using Type = typename Gemm::EpilogueOutputOp::ElementScalar; -}; - - -template -struct IsF8F6F4Kernel { - static constexpr bool value = false; -}; - -template -struct IsF8F6F4Kernel> { - static constexpr bool value = true; -}; - - -// The maximum swizzle size to use -// -// This class, like Splits above makes it harder to confuse -// the order of arguments of the various run(...) functions in this file. -class MaxSwizzleSize { -public: - MaxSwizzleSize() = default; - - template && - !cute::is_same_v)) > - explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} - explicit operator int() const { return max_swizzle_size_; } -private: - int max_swizzle_size_ = 1; -}; - -template -auto make_iterator(T* ptr) { - return cute::recast_ptr(ptr); -} - -template -struct IsDefaultEpilogue { - static constexpr bool value = false; -}; - -template -struct IsDefaultEpilogue> { - static constexpr bool value = true; -}; - -template -struct IsDefaultEpilogue> { - static constexpr bool value = true; -}; - -// The number of splits to test. -// -// This class makes it harder to confuse the order of arguments -// of the various run(...) functions in this file. The constructor -// is explicit, so one can't just type 42 (or false, which the -// compiler unhelpfully turns into 0); one has to type Splits(42). -// Splits() picks the default number of splits, 1. -// -// The conversion-to-int operator (operator int()) MUST be explicit! -// Conversion to int MUST require static_cast. -// Otherwise, that defeats a key purpose of this class, -// which is to catch common errors of confusing the order -// of function arguments. -class Splits { -public: - Splits() = default; - - template && - !cute::is_same_v)) > - explicit Splits(IntegralNotBool splits) : splits_(splits) {} - explicit operator int() const { return splits_; } -private: - int splits_ = 1; -}; - -// The number of iterations to test. -// -// This class, like Splits above makes it harder to confuse -// the order of arguments of the various run(...) functions in this file. -// Iterations() picks the default number of iterations, 20. -class Iterations { -public: - Iterations() = default; - - template && - !cute::is_same_v)) > - explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} - explicit operator int() const { return iterations_; } -private: - int iterations_ = 20; -}; - -template -bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } - - else if (bits_input <= 6) { - scope_max = 2; - scope_min = -2; - } - - else if (bits_input <= 8) { - - if constexpr ( - cute::is_same_v){ - scope_max = 4; - scope_min = 1; - } - else { - - scope_max = 1; - scope_min = -1; - - } - - } - else{ - scope_max = 4; - scope_min = -4; - } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - - else if (dist_kind == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(view); - } - - else if (dist_kind == cutlass::Distribution::Gaussian) { - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - - else if (dist_kind == cutlass::Distribution::AllOnes) { - cutlass::reference::host::TensorFill(view, Element(1)); - } - - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; -} - -// Looks at Cute Stride to check Row / Column Major -template -static constexpr bool is_row_or_col_major(){ - int stride_0 = int(cute::size<0>(Stride{})); - int stride_1 = int(cute::size<1>(Stride{})); - int depth = cute::depth(Stride{}); - return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); -} - - -// -// Default MMA input Operands : A , B -// -template< - class ScheduleType_, - class Gemm, - class ElementA_ = typename Gemm::GemmKernel::ElementA, - class ElementB_ = typename Gemm::GemmKernel::ElementB> -struct HostCollectiveMainloop { - // Kernel data types - using ElementA = ElementA_; - using StrideA = typename Gemm::GemmKernel::StrideA; - using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; - using ElementB = ElementB_; - using StrideB = typename Gemm::GemmKernel::StrideB; - using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - - static constexpr bool IsGroupGemm = !cute::is_same_v; - - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - - using Arguments = typename Gemm::GemmKernel::MainloopArguments; - - cutlass::ComplexTransform TransformA = Gemm::kTransformA; - cutlass::ComplexTransform TransformB = Gemm::kTransformB; - - std::vector stride_a_host; - std::vector stride_b_host; - - cutlass::DeviceAllocation stride_a_device; - cutlass::DeviceAllocation stride_b_device; - - typename LayoutTagA::Stride stride_factor_A; - typename LayoutTagB::Stride stride_factor_B; - - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - - std::vector> tensors_A; - std::vector> tensors_B; - cutlass::DeviceAllocation device_tensors_A; - cutlass::DeviceAllocation device_tensors_B; - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed, - typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), - typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() - ): - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - init_A(init_A_), init_B(init_B_), seed(seed_), - check_relative_equality(check_relative_equality_) { } - - bool initialize(ProblemShapeType problem_shapes) { - // - // Allocate the GEMM workspace - // - // for pointer array problem_shapes.groups() is 1 - - tensors_A.clear(); - tensors_B.clear(); - stride_a_host.clear(); - stride_b_host.clear(); - - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = cutlass::platform::max(problem_shapes.groups(), L); - - for(int32_t i = 0; i < L; ++i) { - auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); - - stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); - stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto a_coord = cutlass::make_Coord(M, K); - // Cutlass has Row/Col major refers to MxK times KxN matrix product, - // so the HostTensorB should be treated as KxN in "coord"'s view - auto b_coord = cutlass::make_Coord(K, N); - - tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); - tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); - - EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); - EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensors_A[i].host_view().at({0, 0}) = ElementA(1); - tensors_B[i].host_view().at({0, 0}) = ElementB(1); - - tensors_A[i].sync_device(); - tensors_B[i].sync_device(); - } - - return true; - } - - Arguments to_args(ProblemShapeType problem_shapes) { - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = cutlass::platform::max(problem_shapes.groups(), L); - - std::vector ptr_A_host(L); - std::vector ptr_B_host(L); - - for (int32_t i = 0; i < L; ++i) { - ptr_A_host.at(i) = tensors_A[i].device_data(); - ptr_B_host.at(i) = tensors_B[i].device_data(); - } - - device_tensors_A.reset(L); - device_tensors_A.copy_from_host(ptr_A_host.data()); - - device_tensors_B.reset(L); - device_tensors_B.copy_from_host(ptr_B_host.data()); - - stride_a_device.reset(problem_shapes.groups()); - stride_a_device.copy_from_host(stride_a_host.data()); - stride_b_device.reset(problem_shapes.groups()); - stride_b_device.copy_from_host(stride_b_host.data()); - - Arguments arguments; - - if constexpr (IsGroupGemm) { - arguments - = - { - device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() - }; - } - else { - arguments = - { - device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] - }; - } - - return arguments; - } - - auto to_host_args(ProblemShapeType problem_shapes, int batch) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); - auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), - make_layout(make_shape(M, K, 1), stride_a_host[batch])); - auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), - make_layout(make_shape(N, K, 1), stride_b_host[batch])); - - cutlass::reference::host::GettMainloopParams mainloop_params{}; - - mainloop_params.A = A; - mainloop_params.B = B; - mainloop_params.transform_A = TransformA; - mainloop_params.transform_B = TransformB; - - return mainloop_params; - } - - void print_tensors(std::ofstream& file, int batch) { - file << "A =\n" << tensors_A[batch].host_view() - << "\nB =\n" << tensors_B[batch].host_view(); - } - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - - // Factors used for calculating relative equality. CUTLASS's relative-equality - // checks in include/cutlass/relatively_equal.h are inspired by - // https://floating-point-gui.de/errors/comparison/. This reference suggests using - // the minimum normal value of a given type as the nonzero_floor. - Element epsilon(static_cast(0.1f)); - Element nonzero_floor(std::numeric_limits::min()); - - if constexpr (!cutlass::is_complex::value) { - if (check_relative_equality == CheckEquality::RELATIVE) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - bool compare_reference( - ProblemShapeType problem_shapes, int batch) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); - - bool passed = true; - return passed; - } -}; - - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - // Kernel data types - using ElementA = ElementA_; - using StrideA = typename Gemm::GemmKernel::StrideA; - using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; - using ElementB = ElementB_; - using StrideB = typename Gemm::GemmKernel::StrideB; - using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; - using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; - - static constexpr bool IsGroupGemm = !cute::is_same_v; - - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; - - static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; - - using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; - using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; - using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; - using InternalLayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; - using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; - using InternalLayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; - - using Arguments = typename Gemm::GemmKernel::MainloopArguments; - - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - - std::vector stride_a_host; - std::vector stride_b_host; - cutlass::DeviceAllocation stride_a_device; - cutlass::DeviceAllocation stride_b_device; - - std::vector layout_sfa_host; - std::vector layout_sfb_host; - cutlass::DeviceAllocation layout_sfa_device; - cutlass::DeviceAllocation layout_sfb_device; - - typename LayoutTagA::Stride stride_factor_A; - typename LayoutTagB::Stride stride_factor_B; - - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - - std::vector> tensors_A; - std::vector> tensors_B; - std::vector> tensors_SFA; - std::vector> tensors_SFB; - - cutlass::DeviceAllocation device_tensors_A; - cutlass::DeviceAllocation device_tensors_B; - cutlass::DeviceAllocation device_tensors_SFA; - cutlass::DeviceAllocation device_tensors_SFB; - - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - // Note: this limitation comes from testbed / not the library - static_assert(is_row_or_col_major(), - "ERROR : A Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : B Layout is neither Row / Column Major)"); - - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed, - typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), - typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() - ): - check_relative_equality(check_relative_equality_), - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - init_A(init_A_), init_B(init_B_), seed(seed_) { } - - template - bool initialize(ProblemShapeType problem_shapes) { - // - // Allocate the GEMM workspace - // - - tensors_A.clear(); - tensors_B.clear(); - stride_a_host.clear(); - stride_b_host.clear(); - tensors_SFA.clear(); - tensors_SFB.clear(); - layout_sfa_host.clear(); - layout_sfb_host.clear(); - - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = std::max(problem_shapes.groups(), L); - - for (int32_t i = 0; i < L; ++i) { - auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); - - stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); - stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto a_coord = cutlass::make_Coord(M, K); - // Cutlass has Row/Col major refers to MxK times KxN matrix product, - // so the HostTensorB should be treated as KxN in "coord"'s view - auto b_coord = cutlass::make_Coord(K, N); - - tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); - tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); - - EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); - EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensors_A[i].host_view().at({0, 0}) = ElementA(1); - tensors_B[i].host_view().at({0, 0}) = ElementB(1); - - tensors_A[i].sync_device(); - tensors_B[i].sync_device(); - - using namespace cute; - - auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); - auto m_blks = cutlass::ceil_div(M, Blk_MN{}); - auto n_blks = cutlass::ceil_div(N, Blk_MN{}); - layout_sfa_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); - layout_sfb_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); - auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); - - tensors_SFA.push_back(cutlass::HostTensor(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A))); - tensors_SFB.push_back(cutlass::HostTensor(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B))); - - EXPECT_TRUE(initialize_tensor(tensors_SFA[i].host_view(), init_A, seed + 2024 + i)); - EXPECT_TRUE(initialize_tensor(tensors_SFB[i].host_view(), init_B, seed + 2025 + i)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensors_SFA[i].host_view().at({0, 0}) = ElementSF(1); - tensors_SFB[i].host_view().at({0, 0}) = ElementSF(1); - - tensors_SFA[i].sync_device(); - tensors_SFB[i].sync_device(); - } - - return true; - } - - Arguments to_args(ProblemShapeType problem_shapes) { - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = std::max(problem_shapes.groups(), L); - - std::vector ptr_A_host(L); - std::vector ptr_B_host(L); - std::vector ptr_SFA_host(L); - std::vector ptr_SFB_host(L); - - for (int32_t i = 0; i < L; ++i) { - ptr_A_host.at(i) = tensors_A[i].device_data(); - ptr_B_host.at(i) = tensors_B[i].device_data(); - ptr_SFA_host.at(i) = tensors_SFA[i].device_data(); - ptr_SFB_host.at(i) = tensors_SFB[i].device_data(); - } - - device_tensors_A.reset(L); - device_tensors_A.copy_from_host(ptr_A_host.data()); - - device_tensors_B.reset(L); - device_tensors_B.copy_from_host(ptr_B_host.data()); - - device_tensors_SFA.reset(L); - device_tensors_SFA.copy_from_host(ptr_SFA_host.data()); - - device_tensors_SFB.reset(L); - device_tensors_SFB.copy_from_host(ptr_SFB_host.data()); - - stride_a_device.reset(problem_shapes.groups()); - stride_a_device.copy_from_host(stride_a_host.data()); - - stride_b_device.reset(problem_shapes.groups()); - stride_b_device.copy_from_host(stride_b_host.data()); - - layout_sfa_device.reset(problem_shapes.groups()); - layout_sfa_device.copy_from_host(layout_sfa_host.data()); - - layout_sfb_device.reset(problem_shapes.groups()); - layout_sfb_device.copy_from_host(layout_sfb_host.data()); - - if constexpr (IsGroupGemm) { - return Arguments{ - device_tensors_A.get(), stride_a_device.get(), - device_tensors_B.get(), stride_b_device.get(), - device_tensors_SFA.get(), layout_sfa_device.get(), - device_tensors_SFB.get(), layout_sfb_device.get() - }; - } - else { - return Arguments{ - device_tensors_A.get(), stride_a_host[0], - device_tensors_B.get(), stride_b_host[0], - device_tensors_SFA.get(), layout_sfa_host[0], - device_tensors_SFB.get(), layout_sfb_host[0] - }; - } - } - - auto to_host_args(ProblemShapeType problem_shapes, int batch) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); - auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), - make_layout(make_shape(M, K, 1), stride_a_host[batch])); - auto SfA = make_tensor(tensors_SFA[batch].host_data(), layout_sfa_host[batch]); - - auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), - make_layout(make_shape(N, K, 1), stride_b_host[batch])); - auto SfB = make_tensor(tensors_SFB[batch].host_data(), layout_sfb_host[batch]); - - return cutlass::reference::host::GettMainloopParams - {A, SfA, B, SfB}; - } - - void print_tensors(std::ofstream& file, int batch) { - file << "A =\n" << tensors_A[batch].host_view() - << "\nB =\n" << tensors_B[batch].host_view() - << "\nSFA =\n" << tensors_SFA[batch].host_view() - << "\nSFB =\n" << tensors_SFB[batch].host_view(); - } - - bool compare_reference( - ProblemShapeType problem_shapes, int batch) { - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFA[batch].host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFB[batch].host_view()), 0); - return true; - } -}; - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} -}; - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} -}; - -// -// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB -// -template< - class Gemm, - int SchedulerPipelineStageCount_, - int AccumulatorPipelineStageCount_, - class ElementA_, - class ElementB_ -> -struct HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> : public - HostCollectiveMainloop, - Gemm, ElementA_, ElementB_> { - using Base = HostCollectiveMainloop, - Gemm, ElementA_, ElementB_>; - HostCollectiveMainloop( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - uint64_t seed_ = Base::kDefaultSeed, - typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), - typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() - ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} -}; - -template -struct HostCollectiveDefaultEpilogue { - // fusion types are potentially void if the fusion is not supported - // helper so we don't try to construct HostTensor with void type - template - using non_void_t = cute::conditional_t, U, T>; - - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using kernel = typename Gemm::GemmKernel; - using Epilogue = typename kernel::CollectiveEpilogue; - - using ElementD = typename kernel::ElementD; - using StrideD = typename kernel::StrideD; - using InternalStrideD = typename kernel::InternalStrideD; - using ElementC = non_void_t; - using StrideC = typename kernel::StrideC; - using InternalStrideC = typename kernel::InternalStrideC; - - static constexpr bool IsGroupGemm = !cute::is_same_v; - - using FusionOp = typename Gemm::EpilogueOutputOp; - - static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - static_assert(is_row_or_col_major(), - "ERROR : C Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : D Layout is neither Row / Column Major)"); - - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - using ElementAccumulator = typename kernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename kernel::ProblemShape; - using ElementCompute = typename ElementComputeType::Type; - using ElementScalar = typename ElementScalarType::Type; - - using Arguments = typename Gemm::GemmKernel::EpilogueArguments; - - /// Initialization - cutlass::DeviceAllocation stride_c_device; - cutlass::DeviceAllocation stride_d_device; - - std::vector stride_c_host; - std::vector stride_d_host; - - typename LayoutTagC::Stride stride_factor_C; - typename LayoutTagD::Stride stride_factor_D; - - // Inputs - ElementScalar alpha; - ElementScalar beta; - - std::vector> tensors_C; - std::vector> tensors_D; - std::vector> references_D; - cutlass::DeviceAllocation device_tensors_C; - cutlass::DeviceAllocation device_tensors_D; - - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - // Are scalars copied to device memory before kernel launch - ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; - // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector - VectorScale vector_scale_mode = VectorScale::DISABLED; - - cutlass::Distribution::Kind init_C; - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - HostCollectiveDefaultEpilogue( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), - stride_factor_D(typename LayoutTagD::Stride()), - check_relative_equality(check_relative_equality_), - use_device_scalars(use_device_scalars_){ } - - bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { - // Initialize Epilogue tensors - - tensors_C.clear(); - tensors_D.clear(); - references_D.clear(); - stride_c_host.clear(); - stride_d_host.clear(); - - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = cutlass::platform::max(problem_shapes.groups(), L); - - for (int32_t i = 0; i < L; ++i) { - auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); - - stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); - stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); - - // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode - auto c_coord = cutlass::make_Coord(M, N); - - tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); - tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); - references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); - EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); - tensors_C[i].host_view().at({0, 0}) = ElementC(1); - - cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); - tensors_C[i].sync_device(); - tensors_D[i].sync_device(); - } - alpha = alpha_; - beta = beta_; - - return true; - } - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - - // Factors used for calculating relative equality. CUTLASS's relative-equality - // checks in include/cutlass/relatively_equal.h are inspired by - // https://floating-point-gui.de/errors/comparison/. This reference suggests using - // the minimum normal value of a given type as the nonzero_floor. - Element epsilon(static_cast(0.1f)); - Element nonzero_floor(std::numeric_limits::min()); - - if constexpr (!cutlass::is_complex::value) { - if (check_relative_equality == CheckEquality::RELATIVE) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - bool compare_reference( - ProblemShapeType problem_shapes, - ElementScalar alpha, - ElementScalar beta, - int batch) { - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = cutlass::platform::max(problem_shapes.groups(), L); - - tensors_D[batch].sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); - - if (tensors_D[batch].size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); - } - - if (references_D[batch].size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); - } - - bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); - if(!passed) { - std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); - L = cutlass::platform::max(problem_shapes.groups(), L); - - std::vector ptr_C_host(L); - std::vector ptr_D_host(L); - - for (int32_t i = 0; i < L; ++i) { - ptr_C_host.at(i) = tensors_C[i].device_data(); - ptr_D_host.at(i) = tensors_D[i].device_data(); - } - - device_tensors_C.reset(L); - device_tensors_C.copy_from_host(ptr_C_host.data()); - - device_tensors_D.reset(L); - device_tensors_D.copy_from_host(ptr_D_host.data()); - - stride_c_device.reset(problem_shapes.groups()); - stride_c_device.copy_from_host(stride_c_host.data()); - - stride_d_device.reset(problem_shapes.groups()); - stride_d_device.copy_from_host(stride_d_host.data()); - - Arguments arguments; - if constexpr (IsGroupGemm) { - arguments = - { - {alpha, beta}, - device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() - }; - } - else { - arguments = - { - {alpha, beta}, - device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] - }; - } - - return arguments; - } - - auto to_host_args(ProblemShapeType problem_shapes, int batch) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); - L = std::max(problem_shapes.groups(), L); - - auto coord_0 = cutlass::make_Coord(0); - auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), - cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); - auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), - cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); - - cutlass::reference::host::GettEpilogueParams< - ElementScalar, - ElementScalar, - ElementAccumulator, - ElementCompute, - decltype(C), - decltype(D)> - epilogue_params{}; - - epilogue_params.C = C; - epilogue_params.D = D; - epilogue_params.alpha = alpha; - epilogue_params.beta = beta; - - return epilogue_params; - } -}; - -template -struct HostCollectiveEpilogue { - // fusion types are potentially void if the fusion is not supported - // helper so we don't try to construct HostTensor with void type - template - using non_void_t = cute::conditional_t, U, T>; - - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - using kernel = typename Gemm::GemmKernel; - using Epilogue = typename kernel::CollectiveEpilogue; - static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); - - using ElementD = typename kernel::ElementD; - using StrideD = typename kernel::StrideD; - using InternalStrideD = typename kernel::InternalStrideD; - using ElementC = non_void_t; - using StrideC = typename kernel::StrideC; - using InternalStrideC = typename kernel::InternalStrideC; - - static constexpr bool IsGroupGemm = !cute::is_same_v; - - static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - - static_assert(is_row_or_col_major(), - "ERROR : C Layout is neither Row / Column Major)"); - static_assert(is_row_or_col_major(), - "ERROR : D Layout is neither Row / Column Major)"); - - // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; - using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - using ElementAccumulator = typename kernel::ElementAccumulator; - using ElementScalingFactor = ElementAccumulator; - using ProblemShapeType = typename kernel::ProblemShape; - - // - // FusionOperation derived types/queries - // - using EpiloguePolicy = typename Epilogue::DispatchPolicy; - static constexpr bool IsLegacy = - cute::is_same_v< - EpiloguePolicy, - cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< - EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> - >; - - using FusionOp = typename Gemm::EpilogueOutputOp; - static_assert(cute::is_base_of_v); - - - // Scale factor Generation related - using SfStrategy = cutlass::reference::host::SfStrategy; - static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; - static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; - static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; - using ElementSFD = non_void_t, ElementD>; - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< - SFD_VectorSize - >; - using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; - using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; - std::vector> tensors_SFD; - std::vector> references_SFD; - cutlass::DeviceAllocation device_tensors_SFD; - - using ElementCompute = typename FusionOp::ElementCompute; - using ElementScalar = typename FusionOp::ElementScalar; - using ElementBias = non_void_t; - using ElementAux = non_void_t; - using ElementAmax = non_void_t; - using LayoutTagAux = non_void_t; - using ActivationFunctor = non_void_t>; - - static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; - static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; - static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; - static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; - static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; - static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; - static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && - (cute::is_same_v || - cute::is_same_v); - static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && - (cute::is_same_v || - cute::is_same_v); - - using Arguments = typename Gemm::GemmKernel::EpilogueArguments; - - /// Initialization - cutlass::DeviceAllocation stride_c_device; - cutlass::DeviceAllocation stride_d_device; - - std::vector stride_c_host; - std::vector stride_d_host; - - typename LayoutTagC::Stride stride_factor_C; - typename LayoutTagD::Stride stride_factor_D; - - // Inputs - cutlass::HostTensor alpha; - cutlass::HostTensor beta; - cutlass::HostTensor scale_A; - cutlass::HostTensor scale_B; - cutlass::HostTensor scale_C; - cutlass::HostTensor scale_D; - cutlass::HostTensor scale_Aux; - cutlass::HostTensor bias; - std::vector> tensors_C; - cutlass::DeviceAllocation device_tensors_C; - cutlass::HostTensor norm_constant; - - // Outputs - cutlass::HostTensor abs_max_Aux; - cutlass::HostTensor abs_max_D; - std::vector> tensors_Aux; - cutlass::DeviceAllocation device_tensors_Aux; - cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; - std::vector> tensors_D; - std::vector> references_D; - cutlass::DeviceAllocation device_tensors_D; - - // References - cutlass::HostTensor reference_dbias; - std::vector> references_Aux; - cutlass::HostTensor reference_abs_max_Aux; - cutlass::HostTensor reference_abs_max_D; - - // Whether to use relative equality checks - CheckEquality check_relative_equality = CheckEquality::EXACT; - // Are scalars copied to device memory before kernel launch - ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; - // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector - VectorScale vector_scale_mode = VectorScale::DISABLED; - - // Random distribution with which to initialize the A/B/C/D/Aux scaling factors - cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; - // Random distribution with which to initialize the bias vector - cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; - cutlass::Distribution::Kind init_C; - uint64_t seed; - static constexpr uint64_t kDefaultSeed = 4096; - - HostCollectiveEpilogue( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): init_scale(init_scale_), init_bias(init_bias_), - init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), - stride_factor_D(typename LayoutTagD::Stride()), - check_relative_equality(check_relative_equality_), - use_device_scalars(use_device_scalars_){ } - - bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { - // Initialize Epilogue tensors - - tensors_C.clear(); - tensors_D.clear(); - references_D.clear(); - stride_c_host.clear(); - stride_d_host.clear(); - - tensors_SFD.clear(); - references_SFD.clear(); - - - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = std::max(problem_shapes.groups(), L); - - for (int32_t i = 0; i < L; ++i) { - auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); - - stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); - stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); - - auto c_coord = cutlass::make_Coord(M, N); - tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); - tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); - references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); - EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); - tensors_C[i].host_view().at({0, 0}) = ElementC(1); - - cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); - tensors_C[i].sync_device(); - tensors_D[i].sync_device(); - } - - auto scalar_coord = cutlass::make_Coord(1); - auto col_vector_coord = cutlass::make_Coord(M); - if constexpr (IsPerRowScaleEnabled) { - alpha.resize(col_vector_coord); - EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); - if (vector_scale_mode == VectorScale::DISABLED) { - beta.resize(scalar_coord, false); - cutlass::reference::host::TensorFill(beta.host_view(), beta_); - } - else { - beta.resize(col_vector_coord); - EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); - } - } - else { - alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); - cutlass::reference::host::TensorFill(beta.host_view(), beta_); - } - alpha.sync_device(); - beta.sync_device(); - - if constexpr (IsScaleFactorEnabled) { - scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); - EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); - EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); - EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); - scale_A.sync_device(); - scale_B.sync_device(); - scale_C.sync_device(); - scale_D.sync_device(); - } - - if constexpr (IsBiasEnabled) { - bias.resize(col_vector_coord); - EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); - bias.sync_device(); - } - - if constexpr (IsDeBiasEnabled) { - bias.resize(col_vector_coord); - reference_dbias.resize(col_vector_coord); - cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); - cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); - bias.sync_device(); - } - - if constexpr (IsAbsMaxEnabledD) { - abs_max_D.resize(scalar_coord); - // ensure in-place device reductions perform their own initialization - cutlass::reference::host::TensorFill(abs_max_D.host_view(), - CUTLASS_STL_NAMESPACE::numeric_limits::max()); - abs_max_D.sync_device(); - reference_abs_max_D.resize(scalar_coord); - cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); - } - - tensors_Aux.clear(); - references_Aux.clear(); - - static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); - - if constexpr (IsAuxInEnabled) { - auto aux_coord = cutlass::make_Coord(M, N); - auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); - for (int32_t i = 0; i < L; ++i) { - tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); - EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); - tensors_Aux[i].sync_device(); - } - stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); - } - - static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); - - if constexpr (IsAuxOutEnabled) { - for (int32_t i = 0; i < L; ++i) { - auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); - auto aux_coord = cutlass::make_Coord(M, N); - auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); - tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); - references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); - tensors_Aux[i].sync_device(); - } - - stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); - - if constexpr (IsScaleFactorEnabled) { - scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); - EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); - scale_Aux.sync_device(); - } - - if constexpr (IsAbsMaxEnabledAux) { - abs_max_Aux.resize(scalar_coord); - // ensure in-place device reductions perform their own initialization - cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), - CUTLASS_STL_NAMESPACE::numeric_limits::max()); - abs_max_Aux.sync_device(); - reference_abs_max_Aux.resize(scalar_coord); - cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); - } - } - - - if constexpr (IsBlockScaleSupported) { - for (int32_t i = 0; i < L; ++i) { - auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); - // If block scaled output is supported we always have at least 1 SFD - auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); - auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); - auto sfd_coord = [&] () { - return cutlass::make_Coord(m_blks * Blk_MN{}, n_blks * Blk_SF{}); - }(); - tensors_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D))); - references_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false)); - tensors_SFD[i].sync_device(); - } - norm_constant.resize(scalar_coord, true); - EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); - norm_constant.sync_device(); - } - - - return true; - } - - template < - class Element, - class Layout - > - bool equality_check( - cutlass::TensorView const& lhs, - cutlass::TensorView const& rhs) const { - - // Factors used for calculating relative equality. CUTLASS's relative-equality - // checks in include/cutlass/relatively_equal.h are inspired by - // https://floating-point-gui.de/errors/comparison/. This reference suggests using - // the minimum normal value of a given type as the nonzero_floor. - Element epsilon(static_cast(0.1f)); - Element nonzero_floor(std::numeric_limits::min()); - - if constexpr (!cutlass::is_complex::value) { - if (check_relative_equality == CheckEquality::RELATIVE) { - return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, epsilon, nonzero_floor); - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - else { - return cutlass::reference::host::TensorEquals(lhs, rhs); - } - } - - bool compare_reference( - ProblemShapeType problem_shapes, - ElementScalar alpha, - ElementScalar beta, - int batch) { - tensors_D[batch].sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); - - if (tensors_D[batch].size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); - } - - if (references_D[batch].size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); - } - - bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); - if(!passed) { - std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); - L = std::max(problem_shapes.groups(), L); - - std::vector ptr_C_host(L); - std::vector ptr_D_host(L); - - for (int32_t i = 0; i < L; ++i) { - ptr_C_host.at(i) = tensors_C[i].device_data(); - ptr_D_host.at(i) = tensors_D[i].device_data(); - } - - device_tensors_C.reset(L); - device_tensors_C.copy_from_host(ptr_C_host.data()); - - device_tensors_D.reset(L); - device_tensors_D.copy_from_host(ptr_D_host.data()); - - stride_c_device.reset(problem_shapes.groups()); - stride_c_device.copy_from_host(stride_c_host.data()); - - stride_d_device.reset(problem_shapes.groups()); - stride_d_device.copy_from_host(stride_d_host.data()); - - std::vector ptr_Aux_host(L); - if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { - for (int32_t i = 0; i < L; ++i) { - ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); - } - device_tensors_Aux.reset(L); - device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); - } - - auto device_tensors_C_ptr = cute::is_void_v ? nullptr : - reinterpret_cast(device_tensors_C.get()); - - Arguments arguments; - if constexpr (IsGroupGemm) { - arguments = - { - {}, - device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() - }; - } - else { - arguments = - { - {}, - device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] - }; - } - - auto &fusion_args = arguments.thread; - if constexpr (IsLegacy) { - arguments.thread = { - alpha.at(coord_0), - beta.at(coord_0), - alpha.device_data(), - beta.device_data() - }; - arguments.ptr_Bias = bias.device_data(); - arguments.ptr_T = device_tensors_Aux.get(); - } - else { - fusion_args.alpha = alpha.at(coord_0); - fusion_args.beta = beta.at(coord_0); - - fusion_args.alpha_ptr = alpha.device_data(); - // can_implement requires beta_ptr to not be set if its voidC - fusion_args.beta_ptr = cute::is_void_v ? nullptr : - beta.device_data(); - - if constexpr (IsScaleFactorEnabled) { - fusion_args.scale_a = scale_A.at(coord_0); - fusion_args.scale_b = scale_B.at(coord_0); - fusion_args.scale_c = scale_C.at(coord_0); - fusion_args.scale_d = scale_D.at(coord_0); - fusion_args.scale_a_ptr = scale_A.device_data(); - fusion_args.scale_b_ptr = scale_B.device_data(); - fusion_args.scale_c_ptr = scale_C.device_data(); - fusion_args.scale_d_ptr = scale_D.device_data(); - } - - if constexpr (IsBiasEnabled) { - fusion_args.bias_ptr = bias.device_data(); - } - - if constexpr (IsDeBiasEnabled) { - fusion_args.dbias_ptr = bias.device_data(); - } - - // example of how to set kernel activation arguments - // see ActivationFunctor::Arguments in activation.h for definition - // if Arguments doesn't exist then fusion_args.activation is empty - if constexpr (cute::is_same_v>) { - fusion_args.activation.scale = ElementCompute(1); - } - - // Treat Clamp as ReLU - if constexpr (cute::is_same_v>) { - fusion_args.activation.lower_bound = 0; - fusion_args.activation.upper_bound = std::numeric_limits::max(); - } - - if constexpr (IsAbsMaxEnabledD) { - fusion_args.amax_D_ptr = abs_max_D.device_data(); - } - - if constexpr (IsAuxInEnabled) { - fusion_args.aux_ptr = device_tensors_Aux.get(); - fusion_args.dAux = stride_Aux; - } - - if constexpr (IsAuxOutEnabled) { - fusion_args.aux_ptr = device_tensors_Aux.get(); - fusion_args.dAux = stride_Aux; - if constexpr (IsScaleFactorEnabled) { - fusion_args.scale_aux = scale_Aux.at(coord_0); - fusion_args.scale_aux_ptr = scale_Aux.device_data(); - } - if constexpr (IsAbsMaxEnabledAux) { - fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); - } - } - - if constexpr (IsBlockScaleSupported) { - std::vector ptr_SFD_host(L); - for (int32_t i = 0; i < L; ++i) { - ptr_SFD_host.at(i) = tensors_SFD[i].device_data(); - } - device_tensors_SFD.reset(L); - device_tensors_SFD.copy_from_host(ptr_SFD_host.data()); - - arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); - arguments.thread.norm_constant_ptr = norm_constant.device_data(); - } - - } - - return arguments; - } - - auto to_host_args(ProblemShapeType problem_shapes, int batch) { - using namespace cute; - // - // Allocate the GEMM workspace - // - auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); - auto [M, N, K, L] = problem_shape_MNKL; - auto coord_0 = cutlass::make_Coord(0); - auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), - cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); - auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), - cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); - auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), - cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); - auto Aux = [&]() { - auto ptr = recast_ptr(nullptr); - if (IsAuxInEnabled) { - ptr = detail::make_iterator(tensors_Aux[batch].host_data()); - } else if (IsAuxOutEnabled) { - ptr = detail::make_iterator(references_Aux[batch].host_data()); - } - return cute::make_tensor(ptr, Aux_layout); - }(); - auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), - cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); - auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), - cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); - - auto SfD = [&](){ - if constexpr (IsBlockScaleSupported) { - auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), - Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); - return tensor; - } - else { - // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. - return D; - } - }(); - - - cutlass::reference::host::GettEpilogueParams< - ElementScalar, - ElementScalar, - ElementAccumulator, - ElementCompute, - decltype(C), - decltype(D), - decltype(Bias), - decltype(Aux), - decltype(Valpha), - decltype(Vbeta), - ActivationFunctor - , decltype(SfD) - , Int - , cutlass::plus - , false - , SfGenStrategy - > epilogue_params{}; - - epilogue_params.C = C; - epilogue_params.D = D; - epilogue_params.alpha = alpha.at(coord_0); - epilogue_params.beta = beta.at(coord_0); - - if constexpr (IsScaleFactorEnabled) { - epilogue_params.scale_a = scale_A.at(coord_0); - epilogue_params.scale_b = scale_B.at(coord_0); - epilogue_params.scale_c = scale_C.at(coord_0); - epilogue_params.scale_d = scale_D.at(coord_0); - } - - if constexpr (IsBiasEnabled or IsDeBiasEnabled) { - epilogue_params.Bias = Bias; - } - - if constexpr (IsAbsMaxEnabledD) { - epilogue_params.abs_max_D = reference_abs_max_D.host_data(); - } - - if constexpr (IsAuxInEnabled) { - epilogue_params.Aux = Aux; - } - - if constexpr (IsAuxOutEnabled) { - epilogue_params.Aux = Aux; - if constexpr (IsScaleFactorEnabled) { - epilogue_params.scale_aux = scale_Aux.at(coord_0); - } - if constexpr (IsAbsMaxEnabledAux) { - epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); - } - } - - if constexpr (IsPerRowScaleEnabled) { - epilogue_params.Valpha = Valpha; - if (vector_scale_mode == VectorScale::ENABLED) { - epilogue_params.Vbeta = Vbeta; - } - } - - if constexpr (IsBlockScaleSupported) { - epilogue_params.SfD = SfD; - epilogue_params.st = norm_constant.at(coord_0); - } - - return epilogue_params; - } -}; - -template < - typename Gemm, - template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - bool force_legacy_epilogue = false, - typename ElementA = typename Gemm::GemmKernel::ElementA, - typename ElementB = typename Gemm::GemmKernel::ElementB -> -struct TestbedImpl { - // Kernel data types - using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; - // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type - using HostCollectiveMainloopType = HostCollectiveMainloop; - using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, - HostCollectiveDefaultEpilogue, - HostCollectiveEpilogue>; - - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementCompute = typename ElementComputeType::Type; - using ElementScalar = typename ElementScalarType::Type; - - using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; - using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; - using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; - using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; - - uint32_t sm_count; - // Used to force multi-wave tests for persistent kernel schedules - constexpr static int MaxSmCount = 16; - static constexpr uint64_t kDefaultSeed = 4096; - static constexpr uint32_t mma_promotion_interval = 4; - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - HostCollectiveMainloopType collective_mma_inputs; - CollectiveEpilogue collective_epilogue; - - static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; - - // - // Methods - // - - TestbedImpl( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), - collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } - - TestbedImpl( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = kDefaultSeed - ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), - collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } - - /// Initializes data structures - bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { - collective_mma_inputs.initialize(problem_shapes); - collective_epilogue.initialize(problem_shapes, alpha_, beta_); - - return true; - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - ProblemShapeType problem_shapes, - ElementScalar alpha, - ElementScalar beta, - int batch) - { - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); - - bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); - passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); - EXPECT_TRUE(passed); - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_" - << M << "x" << N << "x" << K << "x" << batch << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; - - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - collective_mma_inputs.print_tensors(file, batch); - collective_epilogue.print_tensors(file, batch); - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - ProblemShapeType problem_shapes, - ElementScalar alpha, - ElementScalar beta) - { - using namespace cute; - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = std::max(problem_shapes.groups(), L); - - bool passed = true; - for (int32_t i = 0; i < L; ++i) { - auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); - auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); - - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - - passed &= compare_reference(problem_shapes, alpha, beta, i); - } - return passed; - } - - /// Determine if the CUDA device is sufficient to run the kernel - bool sufficient() { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); - - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - this->sm_count = properties.multiProcessorCount; - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - printf("failed due to smem_size\n"); - printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); - return false; - } - - return true; - } - - /// Executes one test - bool run( - ProblemShapeType problem_shapes, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - detail::Iterations iterations = detail::Iterations{} - ) - { - - // Fail test if insufficient CUDA device - if (!sufficient()) { - std::cout << "Test failed due to insufficient CUDA device." << std::endl; - return false; - } - - if (!this->initialize(problem_shapes, alpha, beta)) { - std::cerr << "Initialization failed \n"; - return false; - } - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = this->sm_count; - - typename HostCollectiveMainloopType::Arguments mainloop_args; - - mainloop_args = collective_mma_inputs.to_args(problem_shapes); - - if constexpr (IsGroupGemm) { - arguments = - { - cutlass::gemm::GemmUniversalMode::kGrouped, - problem_shapes, - mainloop_args, - collective_epilogue.to_args(problem_shapes), - hw_info - }; - } - else { - arguments = - { - cutlass::gemm::GemmUniversalMode::kArray, - problem_shapes, - mainloop_args, - collective_epilogue.to_args(problem_shapes), - hw_info - }; - } - - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return false; - } - - // - // Run the GEMM - // - - cudaError_t result; - status = gemm_op.initialize(arguments, workspace.get()); - status = gemm_op.run(); - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - bool passed = this->verify(problem_shapes, alpha, beta); - if (!passed) { - std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta - << "\n"; - } - - return passed; - } -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - template class ActivationFunctor = cutlass::epilogue::thread::Identity, - bool force_legacy_epilogue = false, - typename ElementA = typename Gemm::GemmKernel::ElementA, - typename ElementB = typename Gemm::GemmKernel::ElementB -> -struct Testbed3x { - - using TestBedImpl = typename detail::TestbedImpl< - Gemm, - ActivationFunctor, - force_legacy_epilogue, - ElementA, - ElementB - >; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - - using ElementAccumulator = typename TestBedImpl::ElementAccumulator; - using ElementCompute = typename TestBedImpl::ElementCompute; - using ElementScalar = typename TestBedImpl::ElementScalar; - - using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - - static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm; - - // Detail Implementation - TestBedImpl impl_; - - // - // Methods - // - Testbed3x( - CheckEquality check_relative_equality_ = CheckEquality::EXACT, - ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode_ = VectorScale::DISABLED, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed) - : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} - - /// Executes one test - bool run( - typename TestBedImpl::ProblemShapeType problem_shapes, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - detail::Iterations iterations = detail::Iterations{} - ) - { - return impl_.run( - problem_shapes, alpha, beta, iterations); - } -}; - -template < - typename Gemm, - template class ActivationFunctor = cutlass::epilogue::thread::Identity -> -bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { - using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorScale::DISABLED); - - int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; - std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; - - constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; - constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); - - std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; - - int batches[] = {5, 10}; - - bool passed = true; - - for (int batch : batches) { - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - - if constexpr (Testbed3x::IsGroupGemm) { - std::vector problem_sizes_host; - cutlass::DeviceAllocation problem_sizes_device; - - for (int i = 0; i < batch; ++i) { - problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); - } - - problem_sizes_device.reset(problem_sizes_host.size()); - problem_sizes_device.copy_from_host(problem_sizes_host.data()); - - passed = testbed.run( - ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()}, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - } - else { - ProblemShapeType problem_size{{m, n, k, batch}}; - - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - } - - if (!passed) { - std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n"; - return false; - } - } // k - } // n - } // m - } // batch - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestSmall(double alpha = 1.0, double beta = 1.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, - ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode = VectorScale::ENABLED, - std::vector override_problem_size_k = {}) { - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; - using ElementA = typename Gemm::GemmKernel::ElementA; - using ElementB = typename Gemm::GemmKernel::ElementB; - using TiledMma = typename Gemm::GemmKernel::TiledMma; - - static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); - // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. - int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); - int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); - - int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); - int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); - - int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b); - - if constexpr (apply_alignment_offset) { - // If BlockScaled, then min alignment is SFVecSize - static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; - static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; - if constexpr (IsBlockScaleSupported) { - alignment_input = cutlass::round_up(alignment_input, SFVecSize); - } - } - - - using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; - using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; - CtaShape_MNK cta_shape; - Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); - // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime - static constexpr int SmCount = 16; - - float waves[] = {0.5, 2.5}; - int batches[] = {3}; - int cluster_m = 1; - int cluster_n = 1; - - std::vector problem_size_k; - if (override_problem_size_k.empty()) { - // this is to test with min alignment - problem_size_k = {256 - alignment_input, 512 + alignment_input}; - } - else { - problem_size_k = override_problem_size_k; - } - - if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { - typename DispatchPolicy::ClusterShape cluster_shape; - cluster_m = cute::size<0>(cluster_shape); - cluster_n = cute::size<1>(cluster_shape); - } - - bool passed = true; - - for (int batch : batches) { - for (float wave : waves) { - for (int k : problem_size_k) { - int grid_m, grid_n = 0; - float num_grid = wave * SmCount; - - if (cluster_m >= cluster_n) { - grid_m = cluster_m; - grid_n = static_cast(num_grid) / grid_m; - // Align grid_n to cluster_n - grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); - } - else { - grid_n = cluster_n; - grid_m = static_cast(num_grid) / grid_n; - // Align grid_m to cluster_m - grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); - } - - int m = grid_m * cute::size<0>(cta_shape) - alignment_input; // this is just to test with unusual problem shapes - int n = grid_n * cute::size<1>(cta_shape) + alignment_input; - - if constexpr (Testbed3x::IsGroupGemm) { - std::vector problem_sizes_host; - cutlass::DeviceAllocation problem_sizes_device; - for (int i = 0; i < batch; ++i) { - problem_sizes_host.push_back({m * ((i % 2) + 1), n * ((i % 3) + 1), k * ((i % 2) + 1)}); - } - problem_sizes_device.reset(problem_sizes_host.size()); - problem_sizes_device.copy_from_host(problem_sizes_host.data()); - - ProblemShapeType problem_shapes{batch, problem_sizes_device.get(), problem_sizes_host.data()}; - - if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { - for (int i = 0; i < batch; ++i) { - std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape(i) << " \n"; - } - } - passed = testbed.run( - problem_shapes, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - } - else { - ProblemShapeType problem_shapes{{m, n, k, batch}}; - if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { - std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape() << " \n"; - } - passed = testbed.run( - problem_shapes, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - } - - if (!passed) { - std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; - return false; - } - } // k - } // waves - } // batches - - return passed; -} - -template -bool TestSmallFusion(double alpha = 1.0, double beta = 0.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, - ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode = VectorScale::ENABLED) { - return TestSmall( - alpha, beta, check_relative_equality, use_device_scalars, vector_scale_mode); -} - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp deleted file mode 100644 index 8b00f98a97846de175f1c6f95919c483ab4b81da..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp +++ /dev/null @@ -1,515 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface with elementwise tensor-tensor broadcast epilogue -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "testbed_utils.h" -#include "gemm_testbed_3x.hpp" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Testbed3xTensorBroadcast { - - using TestBedImpl = typename detail::TestbedImpl; - using Kernel = typename Gemm::GemmKernel; - using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - - using ElementA = typename Kernel::ElementA; - using StrideA = typename Kernel::StrideA; - using ElementB = typename Kernel::ElementB; - using StrideB = typename Kernel::StrideB; - using ElementC = typename Kernel::ElementC; - using StrideC = typename Kernel::StrideC; - using ElementD = typename Kernel::ElementD; - using StrideD = typename Kernel::StrideD; - - using ElementAccumulator = typename Kernel::ElementAccumulator; - using ElementCompute = typename Epilogue::ElementCompute; - using ElementScalar = typename Epilogue::ElementScalar; - using ProblemShapeType = typename Kernel::ProblemShape; - using ElementBias = typename Epilogue::ElementBias; - using ActivationFunctor = typename Epilogue::ActivationFunctor; - - static constexpr bool IsBinaryOp0Enabled = Epilogue::IsBinaryOp0Enabled; - static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled; - static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled; - - static constexpr bool PerColBias = Epilogue::PerColumnBias; - - using LayoutTagA = typename TestBedImpl::LayoutTagA; - using LayoutTagB = typename TestBedImpl::LayoutTagB; - using LayoutTagC = typename TestBedImpl::LayoutTagC; - using LayoutTagD = typename TestBedImpl::LayoutTagD; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - cutlass::HostTensor bias; - cutlass::HostTensor tensor_C1; - // tensor_C0 is taken from TestbedImpl's tensor_C - - - // Detail Implementation - TestBedImpl impl_; - - // - // Methods - // - Testbed3xTensorBroadcast( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, - init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_) { } - - Testbed3xTensorBroadcast( - typename LayoutTagA::Stride stride_factor_A_, - typename LayoutTagB::Stride stride_factor_B_, - typename LayoutTagC::Stride stride_factor_C_, - typename LayoutTagD::Stride stride_factor_D_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(stride_factor_A_, - stride_factor_B_, - stride_factor_C_, - stride_factor_D_, - CheckEquality::EXACT, ScalarLoc::ON_HOST, VectorScale::ENABLED, - init_A_, - init_B_, - init_C_, - cutlass::Distribution::Uniform, - cutlass::Distribution::Uniform, - seed_) { } - - /// Initializes data structures - void initialize(ProblemShapeType problem_size) { - // - // Allocate the GEMM workspace for A/B/C/D tensor - // - impl_.initialize(problem_size); - } - - void initialize_bias(ProblemShapeType problem_size) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL); - bias.resize(cutlass::Coord<1>(bias_size)); - - EXPECT_TRUE(detail::initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023)); - bias.sync_device(); - } - - void initialize_c1(ProblemShapeType problem_size) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - auto N = cute::get<1>(problem_shape_MNKL); - auto L = cute::get<3>(problem_shape_MNKL); - - auto c_coord = cutlass::make_Coord(M * L, N); - - tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C)); - EXPECT_TRUE(detail::initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2024)); - tensor_C1.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cute::Shape problem_shape_MNKL, - ElementScalar alpha, - ElementScalar beta, - bool use_bias) - { - auto [M, N, K, L] = problem_shape_MNKL; - - impl_.collective_epilogue.tensor_D.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_B.host_view()), 0); - - if (impl_.collective_epilogue.tensor_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.tensor_D.host_view()), 0); - } - - if (impl_.collective_epilogue.reference_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.reference_D.host_view()), 0); - } - - bool passed = cutlass::reference::host::TensorEquals(impl_.collective_epilogue.reference_D.host_view(), impl_.collective_epilogue.tensor_D.host_view()); - - EXPECT_TRUE(passed); - - if (!passed) { - std::stringstream fname; - fname << "error_Gemm_device_broadcast" - << M << "x" << N << "x" << K << "x" << L << "_" - << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" - << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; - - std::ofstream file(fname.str()); - file - << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L - << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias - << ", per-col bias: " << PerColBias << "\n\n"; - - if (use_bias){ - file << "Bias = \n" << bias.host_view()<< "\n\n"; - } - - file - << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() - << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() - << "\nC0 =\n" << impl_.collective_epilogue.tensor_C.host_view() - << "\nC1 =\n" << tensor_C1.host_view() - << "\n\nReference =\n" << impl_.collective_epilogue.reference_D.host_view() - << "\n\nComputed =\n" <(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - auto N = cute::get<1>(problem_shape_MNKL); - auto K = cute::get<2>(problem_shape_MNKL); - auto L = cute::get<3>(problem_shape_MNKL); - - auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), - cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); - auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), - cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); - auto D = cute::make_tensor(impl_.collective_epilogue.reference_D.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); - auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), - cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1))); - auto C0 = cute::make_tensor(impl_.collective_epilogue.tensor_C.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); - auto C1 = cute::make_tensor(tensor_C1.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); - - // Create host workspace for output of testbed. This computes a portion of the epilogue: - // ref_compute_out = Activation(alpha * (A @ B) + bias) - cutlass::HostTensor ref_compute_out; - auto c_coord = cutlass::make_Coord(M * L, N); - ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C), false); - auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); - - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; - - // Use a dummy null tensor for operand C because the epilogue overrides C. - auto dummy_C = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); - ElementCompute dummy_beta(0); - auto dummy_Aux = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); - auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); - auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); - - auto dummy_SFD = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); - using DummySFDVectorSize = cute::Int<0>; - - - cutlass::reference::host::GettEpilogueParams< - ElementScalar, - ElementScalar, - ElementAccumulator, - ElementCompute, - decltype(dummy_C), - decltype(RefComputeOut), - decltype(Bias), - decltype(dummy_Aux), - decltype(dummy_Valpha), - decltype(dummy_Vbeta), - ActivationFunctor, - decltype(dummy_SFD), - DummySFDVectorSize, - cutlass::plus, - PerColBias> epilogue_params{ - alpha, - dummy_beta, - dummy_C, - RefComputeOut, - Bias, - dummy_Aux, - dummy_Valpha, - dummy_Vbeta - }; - - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - - cutlass::NumericConverter source_converter; - cutlass::NumericConverter destination_converter; - cutlass::multiplies mul; - - // Compute broadcast operations atop the reference - #pragma omp parallel for collapse(3) - for (int64_t l = 0; l < cute::size<2>(A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(A.layout()); ++m) { - for (int64_t n = 0; n < cute::size<0>(B.layout()); ++n) { - ElementCompute intermediate = RefComputeOut(m, n, l); - // Apply BinaryOp0, if needed - if constexpr (IsBinaryOp0Enabled) { - typename Epilogue::ThreadEpilogueOp::BinaryOp0 bin0; - ElementCompute converted_source = source_converter(C0(m, n, l)); - intermediate = bin0(intermediate, mul(beta, converted_source)); - } - - // Apply BinaryOp1, if needed - if constexpr (IsBinaryOp1Enabled) { - typename Epilogue::ThreadEpilogueOp::BinaryOp1 bin1; - ElementCompute converted_source = source_converter(C1(m, n, l)); - intermediate = bin1(intermediate, mul(beta, converted_source)); - } - - // Apply UnaryOp, if needed - if constexpr (IsUnaryOpEnabled) { - typename Epilogue::ThreadEpilogueOp::UnaryOp unary; - intermediate = unary(intermediate); - } - - D(m, n, l) = destination_converter(intermediate); - } - } - } - - return compare_reference(problem_shape_MNKL, alpha, beta, use_bias); - } - - /// Executes one test - bool run( - ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - bool profiling = false, - int iterations = 20, - bool use_bias = true) - { - // Fail test if insufficient CUDA device - if (!impl_.sufficient()) { - std::cout << "Test failed due to insufficient CUDA device." << std::endl; - return false; - } - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - if (not profiling) { - impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); - hw_info.sm_count = impl_.sm_count; - } - else { - impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - hw_info.sm_count = impl_.sm_count; - } - - /// Initializes data structures - /// A/B/C0/D Tensor - initialize(problem_size); - initialize_bias(problem_size); - - if constexpr (IsBinaryOp1Enabled) { - initialize_c1(problem_size); - } - - arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, - impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b, - impl_.mma_promotion_interval - }, - { // Epilogue arguments - { alpha, beta }, // ThreadOp arguments - impl_.collective_epilogue.stride_c, - impl_.collective_epilogue.tensor_D.device_data(), - impl_.collective_epilogue.stride_d, - use_bias ? bias.device_data() : nullptr, - impl_.collective_epilogue.tensor_C.device_data(), - tensor_C1.device_data() - }, // Epilogue arguments end - hw_info - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // - // Run the GEMM - // - - if (profiling) { - return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); - } - else { - cudaError_t result; - status = gemm_op.initialize(arguments, workspace.get()); - status = gemm_op.run(); - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; - return false; - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - bool passed = this->verify(problem_size, alpha, beta, use_bias); - if (!passed) { - std::cout << "Error : Failed : with alpha: " << float(alpha) - << ", beta: " << float(beta) - << ", use_bias: " << use_bias - << "\n"; - } - - return passed; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllTensorBroadcast(bool use_bias=true) { - using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; - std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; - - if constexpr (cute::is_same_v) { - problem_size_m.push_back(768); - problem_size_n.push_back(768); - } - - constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; - constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); - - std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; - - Testbed3xTensorBroadcast testbed; - bool passed = true; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - for (bool use_bias : {true, false}) { - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(1), - false, // profiling - 20, // iterations - use_bias - ); - - if (!passed) { - return false; - } - } - } - } - } - - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(1), - false, // profiling - 20 // iterations - ); - if (!passed) { - return false; - } - } - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h deleted file mode 100644 index 6ae7b864cb272782da4920ffc038830d3b5984b2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h +++ /dev/null @@ -1,300 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/tensor_view_io.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -//////////////////////////////////////////////////////////////////////////////// - -template -struct MultistageTestbed { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = - typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - // - // Methods - // - - MultistageTestbed( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080) - : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} - - /// Helper to initialize a tensor view - template - bool initialize_tensor(cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, uint64_t seed) { - if (dist_kind == cutlass::Distribution::Uniform) { - int scope = (cutlass::sizeof_bits::value == 8) ? 2 : 8; - cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, - -scope, 0); - } else if (dist_kind == cutlass::Distribution::Gaussian) { - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); - } else if (dist_kind == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(view); - } else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), - view.capacity()); - } else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Waives test if CUDA device is insufficient - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run(cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waives test if CUDA device is insufficient - if (!sufficient()) { - return true; - } - - // - // Allocate the GEMM workspace - // - - cutlass::HostTensor - tensor_A(problem_size.mk()); - - cutlass::HostTensor - tensor_B(problem_size.kn()); - - cutlass::HostTensor - tensor_C(problem_size.mn()); - - cutlass::HostTensor - tensor_D(problem_size.mn()); - - cutlass::HostTensor - reference_D(problem_size.mn(), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), - tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - problem_size, tensor_A.device_ref(), tensor_B.device_ref(), - tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.initialize(arguments); - - if (status != cutlass::Status::kSuccess) { - cudaError_t error = cudaGetLastError(); - std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; - return true; - } - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Verify - // - - cutlass::reference::host::Gemm< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, - ElementAccumulator, typename Gemm::Operator> - reference_gemm; - - reference_gemm( - problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, - reference_D.host_ref(), ElementAccumulator(0)); - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - bool passed = cutlass::reference::host::TensorEquals( - reference_D.host_view(), tensor_D.host_view()); - - EXPECT_TRUE(passed); - if (!passed) { - std::stringstream fname; - - fname << "error_Gemm_device_" << problem_size.m() << "x" - << problem_size.n() << "x" << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN - << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM - << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK - << ".txt"; - - std::ofstream file(fname.str()); - - file << "problem: " << problem_size << ", alpha: " << alpha - << ", beta: " << beta << "\n\n"; - - file << "A =\n" - << tensor_A.host_view() << "\nB =\n" - << tensor_B.host_view() << "\nC =\n" - << tensor_C.host_view() << "\n\nReference =\n" - << reference_D.host_view() << "\nComputed =\n" - << tensor_D.host_view(); - } - - return passed; - } - - /// Runs a set of problem sizes - bool run_all() { - bool passed = true; - - int problem_size_m[] = {16, 528}; - - int problem_size_n[] = {16, 528}; - - int problem_size_k[] = {Gemm::InstructionShape::kK, - Gemm::ThreadblockShape::kK * Gemm::kStages + - Gemm::InstructionShape::kK}; - - double problem_alpha[] = {1.0}; - - // TODO Try non zero beta value after multistaged epilogue is implemented - double problem_beta[] = {0.0}; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (double alpha : problem_alpha) { - for (double beta : problem_beta) { - passed = - run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); - - if (!passed) { - return false; - } - } - } - } - } - } - - return true; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h deleted file mode 100644 index e309208bb4311253be5b7366841164eb62748bab..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h +++ /dev/null @@ -1,348 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/host_reorder.h" - -namespace test { -namespace gemm { -namespace device { - -//////////////////////////////////////////////////////////////////////////////// - -template -struct MultistageInterleavedTestbed { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - // - // Methods - // - - MultistageInterleavedTestbed( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, 2, -2, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerMultiprocessor < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - // - // Allocate the GEMM workspace - // - - cutlass::HostTensor< - typename Gemm::ElementA, - typename Gemm::LayoutA> tensor_A(problem_size.mk()); - - cutlass::HostTensor< - typename Gemm::ElementB, - typename Gemm::LayoutB> tensor_B(problem_size.kn()); - - cutlass::HostTensor< - typename Gemm::ElementB, - typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); - - cutlass::HostTensor< - typename Gemm::ElementC, - typename Gemm::LayoutC> tensor_C(problem_size.mn()); - - cutlass::HostTensor< - typename Gemm::ElementC, - typename Gemm::LayoutC> tensor_D(problem_size.mn()); - - cutlass::HostTensor< - typename Gemm::ElementC, - typename Gemm::LayoutC> reference_D(problem_size.mn(), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - - cutlass::reorder_column( - tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); - - cutlass::reference::host::TensorCopy( - reference_D.host_view(), - tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B_reordered.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - problem_size, - tensor_A.device_ref(), - tensor_B_reordered.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), - {alpha, beta} - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.initialize(arguments); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Verify - // - - cutlass::reference::host::Gemm< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, - ElementAccumulator, typename Gemm::Operator> - reference_gemm; - - reference_gemm( - problem_size, - alpha, - tensor_A.host_ref(), - tensor_B.host_ref(), - beta, - reference_D.host_ref(), - ElementAccumulator(0) - ); - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - bool passed = cutlass::reference::host::TensorEquals( - reference_D.host_view(), - tensor_D.host_view()); - - EXPECT_TRUE(passed); - if (!passed) { - - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nB_reordered =\n" << tensor_B_reordered.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\nComputed =\n" << tensor_D.host_view(); - } - - return passed; - } - - /// Runs a set of problem sizes - bool run_all() { - bool passed = true; - - int problem_size_m[] = { - InterleavedK, 512 + InterleavedK - }; - - int problem_size_n[] = { - InterleavedK, 512 + InterleavedK - }; - - int problem_size_k[] = { - InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK - }; - - double problem_alpha[] = { - 1.0 - }; - - double problem_beta[] = { - 0.0 - }; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (double alpha : problem_alpha) { - for (double beta : problem_beta) { - - passed = run( - {m, n, k}, - ElementCompute(alpha), - ElementCompute(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - - return true; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py deleted file mode 100644 index a180028205abb689436c73403eea82758ade7da9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# this file creates the test/unit/gemm/device simt tests - - -outputDir = "" - -################################################################################ -# parameters -# Edge - for tiles, the edges represent the length of one side -# Ratio - the maximum ratio between 2 edges, limits the skinnyness of tiles -# MaxEdge - maximum length of each edge -# Min/Max - minimum/maximum of the product of edge lengths -################################################################################ - -warpsPerThreadblockEdge = [1, 2, 4, 8, 16] -warpsPerThreadblockRatio = 2 -warpsPerThreadblockMax = 16 -# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases - -warpShapeEdges = [8, 16, 32, 64, 128, 256] -warpShapeRatio = 4 -warpShapeMax = 64*64 -warpShapeMin = 8*8 - -threadblockEdgeMax = 256 - -# char, type bits/elem, max tile, L0 threadblock tiles -precisions = [ - ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], - ["q", "cutlass::Quaternion", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], - ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], - ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], - ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], - ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], - ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], - ] -# L1 will have a single kernel for every unique shape -# L2 will have everything else - -transposes = [ - [False, False], - [False, True], - [True, False], - [True, True] - ] - -################################################################################ -# warps per threadblock -################################################################################ -warpsPerThreadblocks = [] -for warpsPerThreadblock0 in warpsPerThreadblockEdge: - for warpsPerThreadblock1 in warpsPerThreadblockEdge: - if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax: - warpsPerThreadblocks.append([warpsPerThreadblock0, - warpsPerThreadblock1]) -print("WarpsPerThreadblocks",warpsPerThreadblocks) - -################################################################################ -# warp shapes -################################################################################ -warpNumThreads = 32 -warpShapes = [] -for warp0 in warpShapeEdges: - for warp1 in warpShapeEdges: - if warp0 / warp1 <= warpShapeRatio and warp1 / warp0 <= warpShapeRatio and warp0*warp1 <= warpShapeMax and warp0*warp1 > warpShapeMin: - warpShapes.append([warp0, warp1]) -print("WarpShapes", warpShapes) - -numL0 = 0 -numL1 = 0 -numL2 = 0 - -################################################################################ -# create kernels -# create a file for each precision/transpose -# each file contains many tile sizes -################################################################################ - -# precisions -for precision in precisions: - - # get precision char - precisionChar = precision[0] - precisionType = precision[1] - precisionBits = precision[2] - threadblockMaxElements = precision[3] - threadblockTilesL0 = precision[4] - - # transposes - for transpose in transposes: - - # get transpose char - columnMajorA = transpose[0] - columnMajorB = transpose[1] - transCharA = "n" if columnMajorA else "t" - transCharB = "n" if columnMajorB else "t" - - # open file - fileName="simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB) - print("\n", fileName) - filePath = "%s%s" % (outputDir, fileName) - out = open(filePath, "w+") - - # write file header - out.write("/***************************************************************************************************\n" -" * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n" -" * SPDX-License-Identifier: BSD-3-Clause \n" -" * \n" -" * Redistribution and use in source and binary forms, with or without \n" -" * modification, are permitted provided that the following conditions are met: \n" -" * \n" -" * 1. Redistributions of source code must retain the above copyright notice, this \n" -" * list of conditions and the following disclaimer. \n" -" * \n" -" * 2. Redistributions in binary form must reproduce the above copyright notice, \n" -" * this list of conditions and the following disclaimer in the documentation \n" -" * and/or other materials provided with the distribution. \n" -" * \n" -" * 3. Neither the name of the copyright holder nor the names of its \n" -" * contributors may be used to endorse or promote products derived from \n" -" * this software without specific prior written permission. \n" -" * \n" -" * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" \n" -" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE \n" -" * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE \n" -" * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE \n" -" * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL \n" -" * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR \n" -" * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER \n" -" * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, \n" -" * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE \n" -" * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \n" -" *\n" -" **************************************************************************************************/\n" -"/*! \\file\n" -" \\brief Tests for device-wide GEMM interface\n" -"*/\n" -"\n" -"#include \n" -"\n" -"#include \"cutlass/cutlass.h\"\n" -"#include \"cutlass/gemm/device/gemm.h\"\n" -"#include \"cutlass/numeric_types.h\"\n" -"\n" -"#include \"../../common/cutlass_unit_test.h\"\n" -"\n" -"#include \"cutlass/util/host_tensor.h\"\n" -"#include \"cutlass/util/tensor_view_io.h\"\n" -"#include \"cutlass/util/reference/host/tensor_fill.h\"\n" -"#include \"cutlass/util/reference/host/tensor_copy.h\"\n" -"#include \"cutlass/util/reference/host/tensor_compare.h\"\n" -"#include \"cutlass/util/reference/host/gemm.h\"\n" -"\n" -"#include \"testbed.h\"\n" -"\n") - foundThreadblockTilesL0 = {} - foundThreadblockTilesL1 = {} - - ######################################################################## - # for each combination of tile sizes - ######################################################################## - for warpsPerThreadblock in warpsPerThreadblocks: - for warpShape in warpShapes: - warpThreadsM = 0 - if warpShape[0] > warpShape[1]: - warpThreadsM = 8 - else: - warpThreadsM = 4 - warpThreadsN = warpNumThreads / warpThreadsM - - # skip shapes with conflicting rectangularity - # they are unlikely to be fastest - blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] - blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] - warpG = warpShape[0] > warpShape[1] - warpL = warpShape[0] < warpShape[1] - - blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2 - blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1] - warpG2 = warpShape[0] > warpShape[1]*2 - warpL2 = warpShape[0]*2 < warpShape[1] - - if blockG2 and warpL: continue - if blockL2 and warpG: continue - if warpG2 and blockL: continue - if warpL2 and blockG: continue - - # check threadblock ratios and max - threadblockTile = [warpShape[0]*warpsPerThreadblock[0], - warpShape[1]*warpsPerThreadblock[1]] - if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue - if threadblockTile[0] > threadblockEdgeMax: continue - if threadblockTile[1] > threadblockEdgeMax: continue - totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1] - - # calculate unroll - # ensure that every iteration at least a full load of A,B are done - unrollMin = 8 - unrollMin0 = totalThreads / threadblockTile[0] - unrollMin1 = totalThreads / threadblockTile[1] - unroll = max(unrollMin, unrollMin0, unrollMin1) - - threadTileM = warpShape[0] / warpThreadsM - threadTileN = warpShape[1] / warpThreadsN - if threadTileM < 2 or threadTileN < 2: continue - if threadTileM*threadTileN*precisionBits > 8*8*32: continue - - # epilogue currently only supports N < WarpNumThreads - if threadblockTile[1] < warpNumThreads: continue - - # limit smem - smemBitsA = threadblockTile[0]*unroll*2*precisionBits - smemBitsB = threadblockTile[1]*unroll*2*precisionBits - smemKBytes = (smemBitsA+smemBitsB)/8/1024 - if (smemKBytes > 48): continue - - # test level 0 - testLevel = -1 - for tileId in range(0, len(threadblockTilesL0)): - tbTile = threadblockTilesL0[tileId] - if tbTile[0] == threadblockTile[0] and tbTile[1] == threadblockTile[1]: - if tuple(tbTile) not in foundThreadblockTilesL0: - testLevel = 0 - numL0 += 1 - foundThreadblockTilesL0[tuple(tbTile)] = True - - # test level 1 - if testLevel < 0: - threadblockTileAlreadyUsed = False - if tuple(threadblockTile) not in foundThreadblockTilesL1: - testLevel = 1 - numL1 += 1 - foundThreadblockTilesL1[tuple(threadblockTile)] = True - - # test level 2 - if testLevel < 0: - testLevel = 2 - numL2 += 1 - - ################################################################ - # write this tile to file - ################################################################ - - print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % ( - threadblockTile[0], threadblockTile[1], unroll, - threadTileM, threadTileN, - warpThreadsM, warpThreadsN, - warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel)) - - out.write("////////////////////////////////////////////////////////////////////////////////\n" - "// Elements / Thread: %3i x %3i\n" - "// Threads / Warp: %3i x %3i\n" - "// Warps / Block: %3i x %3i\n" - "// Threadblock: %3i x %3i x %2i\n" - % ( threadTileM, threadTileN, - warpThreadsM, warpThreadsN, - warpsPerThreadblock[0], warpsPerThreadblock[1], - threadblockTile[0], threadblockTile[1], unroll - ) - ) - - out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % ( - testLevel, - precisionChar, - transCharA, - transCharB, - threadblockTile[0], - threadblockTile[1], - unroll, - warpShape[0], - warpShape[1], - threadTileM, - threadTileN, - warpThreadsM, - warpThreadsN, - warpsPerThreadblock[0], - warpsPerThreadblock[1] - )) - out.write(" using precision = %s;\n" % precisionType) - out.write(" using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % ( - threadblockTile[0], - threadblockTile[1], - unroll)) - out.write(" using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % ( - warpShape[0], - warpShape[1], - unroll)) - out.write(" static int const kEpilogueElementsPerAccess = 1;\n" - " using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n" - " using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n" - " precision, kEpilogueElementsPerAccess, precision, precision>;\n\n") - - out.write(" using Gemm = cutlass::gemm::device::Gemm<\n" - " precision, cutlass::layout::%sMajor,\n" - " precision, cutlass::layout::%sMajor,\n" - " precision, cutlass::layout::RowMajor,\n" - " precision,\n" - " cutlass::arch::OpClassSimt,\n" - " cutlass::arch::Sm50,\n" - " ThreadblockShape, WarpShape, InstructionShape,\n" - " EpilogueOutputOp,\n" - " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\n" - " 2 // Stages\n" - " >;\n" % ( - "Column" if columnMajorA else "Row", - "Column" if columnMajorB else "Row", - )) - out.write(" EXPECT_TRUE(test::gemm::device::TestAllGemm());\n" - "} )\n\n") - - - out.close() -print("NumKernels:", numL0, numL1, numL2) - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp deleted file mode 100644 index 63ffc3281dd2b9e9f74e0024c73da00628331dd4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp +++ /dev/null @@ -1,545 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Host reference and operations for Sm90 EVT unit test -*/ -#pragma once -#include "gemm_testbed_3x_evt.hpp" - -////////////////////////////////////////////////////////////////////////////// -/// Host references used for testing -namespace test::gemm::device { -template -using HEVT = HostTreeVisitor; - -template -using HDAG = HostTopoVisitor; - -template -using HST = HostSplitTreeVisitor; - -/// D = alpha * acc + beta * C + AuxLoad -template -class HostEVTAuxLoad { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using ElementD = typename Gemm::GemmKernel::ElementC; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - using ScalarAlpha = HostScalarBroadcast<1>; - using AccFetchNode = HostAccumulator<>; - using AuxLoadNode = HostAuxLoad; - using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; - using ScalarBeta = HostScalarBroadcast<1>; - using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; - using EVTModule = HEVT, TernaryCompute1>; -}; - -/// D = alpha * acc + beta * C + per-column bias -template -class HostPerColBias { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using ElementD = typename Gemm::GemmKernel::ElementC; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - using ScalarAlpha = HostScalarBroadcast<1>; - using AccFetchNode = HostAccumulator<>; - using RowBroadcastNode = HostRowBroadcast; - using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; - using ScalarBeta = HostScalarBroadcast<1>; - using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; - using EVTModule = HEVT, TernaryCompute1>; -}; - -/// D = beta * C + Graph(relu(alpha * acc + aux) + aux) -/// Testing EVT - DAG structure -template -class HostEVTDAG { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using ElementD = typename Gemm::GemmKernel::ElementC; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - using ScalarAlpha = HostScalarBroadcast<1>; - using AccFetchNode = HostAccumulator<>; - using AuxLoadNode = HostAuxLoad; - using DAGNode = HDAG< - float, - cute::tuple< - cute::tuple<>, // 0. alpha - cute::tuple<>, // 1. acc - cute::tuple<>, // 2. aux load - cute::tuple, // 3. alpha * acc + aux load - cute::tuple, // relu(alpha * acc + aux load) - cute::tuple // relu(alpha * acc + aux load) + aux load - >, - ScalarAlpha, - AccFetchNode, - AuxLoadNode, - HostCompute, - HostCompute, - HostCompute - >; - using ScalarBeta = HostScalarBroadcast<1>; - using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; - using EVTModule = HEVT, TernaryCompute1>; -}; - -/// EVT = alpha * acc + C -/// D = Graph(maximum(EVT + per-row bias, EVT)) -/// Testing DAG - EVT -template -class HostDAGEVT { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using ElementD = typename Gemm::GemmKernel::ElementC; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - using EVTNode = HEVT< - HostAuxStore, - HEVT< - HostCompute, - HostScalarBroadcast<2>, - HostAccumulator<>, - HostAuxLoad - > - >; - using EVTModule = HEVT< - HostAuxStore, - HDAG< - float, - cute::tuple< - cute::tuple<>, // 0. EVT - cute::tuple<>, // 1. per-row bias - cute::tuple, // 2. EVT + per-row bias - cute::tuple // 3. maximum(EVT + per-row bias, EVT) - >, - EVTNode, - HostColBroadcast>, - HostCompute, - HostCompute - > - >; -}; - -/// Xreduce(alpha * acc + beta * C) -template -class HostReduce { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using ElementD = typename Gemm::GemmKernel::ElementC; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - using ScalarAlpha = HostScalarBroadcast<1>; - using AccFetchNode = HostAccumulator<>; - using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; - using ScalarBeta = HostScalarBroadcast<1>; - using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; - using ReduceNode = HEVT; - using EVTModule = HEVT, ReduceNode>; -}; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template class ActivationFn, class ElementD> -class HostScaledLinCombPerRowBiasEltAct { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - using EVTModule = HEVT< - HostAuxStore, - HEVT< - HostCompute::template Op>, // activation(Z) * scaled_d - HEVT< - HostCompute, // activation(Z) - HEVT< - HostCompute, - HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta - HostAuxLoad, // C - HEVT< - HostCompute, - HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha - HostAccumulator<>, - HostColBroadcast> - > - > - >, - HostScalarBroadcast<1> // scale_d - > - >; -}; - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z -template class ActivationFn, class ElementD, class ElementAux = ElementD> -class HostScaledLinCombPerRowBiasEltActAmaxAux { -public: - using ElementC = typename Gemm::GemmKernel::ElementC; - using LayoutC = cutlass::detail::StrideToLayoutTagC_t; - using LayoutD = cutlass::detail::StrideToLayoutTagC_t; - - template - using amax = cutlass::maximum_absolute_value_reduction; - using EVTModuleAuxFp8 = HEVT< - HostAuxStore, - HST, - HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta - HostAuxLoad, // C - HEVT< - HostCompute, - HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha - HostAccumulator<>, - HostColBroadcast> - > - >, - // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) - HEVT< - HostCompute::template Op>, - HEVT< - HostScalarReduce, - HEVT< - HostCompute, //activation(Z) * scaled_d - HostAccumulator<> // Z - > - >, - HostScalarBroadcast<1> // scale_d - >, - // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) - HEVT< - HostAuxStore, - HEVT< - HostCompute, - HEVT< - HostScalarReduce, - HostAccumulator<> - >, - HostScalarBroadcast<1> - > - > - > - >; - - using EVTModuleAuxNotFp8 = HEVT< - // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) - HostAuxStore, - HEVT< - HostCompute::template Op>, - HEVT< - HostScalarReduce, - HEVT< - HostCompute, //activation(Z) * scaled_d - HEVT< - // Aux = Z - HostAuxStore, - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - HEVT< - HostCompute, - HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta - HostAuxLoad, // C - HEVT< - HostCompute, - HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha - HostAccumulator<>, - HostColBroadcast> - > - > - > - > - >, - HostScalarBroadcast<1> // scale_d - > - >; - - using EVTModule = cute::conditional_t, EVTModuleAuxFp8, EVTModuleAuxNotFp8>; - -}; -} // namespace test::gemm::device - -////////////////////////////////////////////////////////////////////////////// -namespace cutlass::epilogue { -namespace fusion { - -namespace detail { - -template -struct maximum_with_default_nan_propagation : maximum {}; - -} // namespace detail - -////////////////////////////////////////////////////////////////////////////// -/// D = alpha * acc + beta * C + AuxLoad -template< - class EpilogueDescriptor, - class AuxLoadDescriptor, - class ElementOutput, - class ElementCompute, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombAuxLoad = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha - Sm90AccFetch, // acc - Sm90AuxLoad< - AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, - typename AuxLoadDescriptor::Element, - typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, - typename AuxLoadDescriptor::CopyOpS2R // aux load - > - > - >; - -////////////////////////////////////////////////////////////////////////////// -/// D = alpha * acc + beta * C + AuxLoadNoSmem -template< - class EpilogueDescriptor, - class ElementAux, - class StrideAux, - class ElementOutput, - class ElementCompute, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombAuxLoadNoSmem = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha - Sm90AccFetch, // acc - Sm90AuxLoad<0, void, ElementAux, StrideAux, void, void> // aux load - > - >; - -////////////////////////////////////////////////////////////////////////////// -/// Example DAG -/// beta * C + Graph(alpha * acc + gamma + acc) -template< - typename EpilogueDescriptor, - typename AuxLoadDescriptor, - class ElementOutput, - class ElementCompute, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombEVTDAG = - Sm90EVT, // beta * C + (alpha * acc + aux) - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90TopologicalVisitor< - ElementCompute, - cute::tuple< - cute::seq<>, // 0. alpha - cute::seq<>, // 1. acc - cute::seq<>, // 2. aux load - cute::seq<1, 0, 2>, // 3. alpha * acc + aux load - cute::seq<3>, // relu(alpha & acc + aux load) - cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load - >, - Sm90ScalarBroadcast, // alpha - Sm90AccFetch, // acc - Sm90AuxLoad< - AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, - typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, - typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, - Sm90Compute, - Sm90Compute, - Sm90Compute - > - >; - - -////////////////////////////////////////////////////////////////////////////// -/// Example DAG -/// EVT = alpha * acc + C -/// D = Graph(maximum(EVT + per-row bias, EVT)) -template< - class EpilogueDescriptor, - class AuxStoreDescriptor, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombDAGEVT = - Sm90TopologicalVisitor< - ElementCompute, - cute::tuple< - cute::seq<>, - cute::seq<>, - cute::seq<1, 0>, - cute::seq<0, 2> - >, - Sm90EVT< - Sm90AuxStore< - AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, - typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, - typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, - Sm90EVT, - Sm90ScalarBroadcast, - Sm90AccFetch, - Sm90SrcFetch - > - >, - Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute>, - Sm90Compute, - Sm90Compute - >; - - -////////////////////////////////////////////////////////////////////////////// -/// D = alpha * acc + beta * C + per-column bias -template< - class EpilogueDescriptor, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColumnBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha - Sm90AccFetch, // acc - Sm90RowBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute> - > - >; - - -////////////////////////////////////////////////////////////////////////////// -/// D = per-column reduce(alpha * acc + beta * C) -template< - template class RegReduceFn, - template class GmemReduceFn, - class ElementReduce, - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColumnReduce = - Sm90EVT, // per column reduce - Sm90EVT, // beta * C + alpha * acc - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast, // alpha - Sm90AccFetch // acc - > - > - >; - - -////////////////////////////////////////////////////////////////////////////// -/// D = per-row reduce(alpha * acc + beta * C) -template< - template class RegReduceFn, - template class GmemReduceFn, - class ElementReduce, - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowReduce = - Sm90EVT, // per column reduce - Sm90EVT, // beta * C + alpha * acc - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast, // alpha - Sm90AccFetch // acc - > - > - >; - - -////////////////////////////////////////////////////////////////////////////// -/// D = scalar reduce(alpha * acc + beta * C) -template< - template class RegReduceFn, - template class GmemReduceFn, - class ElementReduce, - class ElementOutput, - class ElementCompute, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombScalarReduce = - Sm90EVT, // per column reduce - Sm90EVT, // beta * C + alpha * acc - Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast, // alpha - Sm90AccFetch // acc - > - > - >; -} // namespace fusion - -} // namespace cutlass::epilogue diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h deleted file mode 100644 index 0007666cdd084f35015200e36fd47f75971f6c1c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h +++ /dev/null @@ -1,639 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" - -#include "testbed_utils.h" -#include "testbed_universal.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Testbed { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - typename Gemm::LayoutA::Stride stride_factor_A; - typename Gemm::LayoutB::Stride stride_factor_B; - typename Gemm::LayoutC::Stride stride_factor_C; - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // - // Methods - // - - Testbed( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - stride_factor_A(typename Gemm::LayoutA::Stride()), - stride_factor_B(typename Gemm::LayoutB::Stride()), - stride_factor_C(typename Gemm::LayoutC::Stride()), - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - Testbed( - typename Gemm::LayoutA::Stride stride_factor_A_, - typename Gemm::LayoutB::Stride stride_factor_B_, - typename Gemm::LayoutC::Stride stride_factor_C_, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - stride_factor_A(stride_factor_A_), - stride_factor_B(stride_factor_B_), - stride_factor_C(stride_factor_C_), - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 1; - scope_min = -1; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the GEMM workspace - // - - tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); - tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); - tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); - tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); - reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); - tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0) - << "tensor_D (size " << tensor_D.size() << ") has nonpositive norm"; - } - if (reference_D.size() > 1) { - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0) - << "reference_D (size " << reference_D.size() << ") has nonpositive norm"; - } - bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); - - EXPECT_TRUE(passed) << "reference_D does not equal tensor_D"; - - if (!passed) { - - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\nComputed =\n" << tensor_D.host_view(); - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - - cutlass::reference::host::Gemm< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, - ElementAccumulator, typename Gemm::Operator> - reference_gemm; - - reference_gemm( - problem_size, - alpha, - tensor_A.host_ref(), - tensor_B.host_ref(), - beta, - reference_D.host_ref(), - ElementAccumulator(0) - ); - - if (Relu) { - for (int i = 0; i < problem_size.m(); ++i) { - for (int j = 0; j < problem_size.n(); ++j) { - reference_D.at(cutlass::MatrixCoord(i, j)) = - ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) - ? (typename Gemm::ElementC)0 - : reference_D.at(cutlass::MatrixCoord(i, j)); - } - } - } - - return compare_reference(problem_size, alpha, beta); - } - - /// Determine if the CUDA device is sufficient to run the kernel - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - - /// Executes one test - bool run( - cutlass::gemm::GemmCoord problem_size, - int split_k_slices = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) - { -/* - std::cout << "\n-----------------------\n"; - std::cout << "problem size: " << problem_size << "\n"; - std::cout << "split_k_slices: " << split_k_slices << "\n"; - std::cout << "alpha: " << alpha << "\n"; - std::cout << "beta: " << beta << "\n"; - std::cout << "-----------------------\n\n"; -*/ - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), - {alpha, beta}, - split_k_slices - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) - << "gemm_op.initialize returned with error " << to_string(status) - << ", indicating that this test is not supported. Last CUDA error: " - << cudaGetErrorString(cudaGetLastError()); - if (status != cutlass::Status::kSuccess) { - return true; - } - - // - // Run the GEMM - // - - try { - status = gemm_op(); - } - catch (std::exception const& e) { - EXPECT_TRUE(false) << "gemm_op() threw a std::exception: " << e.what(); - throw; - } - catch (...) { - EXPECT_TRUE(false) << "gemm_op() threw an exception of unknown type"; - throw; - } - EXPECT_TRUE(status == cutlass::Status::kSuccess) - << "gemm_op failed with error " << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - EXPECT_TRUE(passed) << "Error: split_k_slices = " << split_k_slices - << ", alpha: " << alpha; - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllGemmBasic( - const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), - const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), - const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { - bool passed = true; - - int const kMinimumOperandElementSize = - std::min( - int(cutlass::sizeof_bits::value), - int(cutlass::sizeof_bits::value)); - - int const kAlignment = cutlass::platform::is_same< - typename Gemm::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - (cutlass::platform::is_same::value || - cutlass::platform::is_same::value) ? 4 : kAlignment; - - int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; - - int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; - - int problem_size_k[] = { - kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; - - int split_k_slices[] = { - 1, 2, 3 - }; - - double problem_alpha[] = { - 1 - }; - - double problem_beta[] = { - 2.0 - }; - - Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int split_k : split_k_slices) { - - if (!Gemm::kSplitKSerial && split_k > 1) { - continue; - } - - if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { - continue; - } - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - cutlass::gemm::GemmCoord problem_size(m, n, k); - try { - passed = testbed.run( - problem_size, - split_k, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - } - catch (std::exception const& e) { - EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " - "exception {alpha: " << alpha << ", beta: " << beta << ", m: " - << m << ", n: " << n << ", k: " << k << "}: " << e.what(); - throw; - } - catch (...) { - EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " - "exception {alpha: " << alpha << ", beta: " << beta << ", m: " - << m << ", n: " << n << ", k: " << k << "}: (unknown)"; - throw; - } - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllGemm( - const typename Gemm::LayoutA::Stride& stride_factor_A, - const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), - const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) -{ - // Test basic GEMM with non-default stride factors - return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); -} - -template -bool TestAllGemm() -{ -#ifdef NDEBUG - // Non-debug builds also test basic GEMM with default stride factors - if (!TestAllGemmBasic()) { - return false; - } -#endif // NDEBUG - - // Test universal GEMM -#if 0 - // Define the universal kernel - using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< - typename Gemm::GemmKernel::Mma, // Mma - typename Gemm::GemmKernel::Epilogue, // Epilogue - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle - >; -#else - // Define the streamk universal kernel - using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< - typename Gemm::GemmKernel::Mma, // Mma - typename Gemm::GemmKernel::Epilogue, // Epilogue - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle - >; -#endif - - // Define the universal adaptor - using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; - - // Test universal GEMM - return TestAllGemmUniversal(); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestGemmPerf(int iterations = 1) { - bool passed = true; - - int problem_size_m[] = { 2048 }; - - int problem_size_n[] = { 4352 }; - - int problem_size_k[] = { 4096 }; - - int split_k_slices[] = { 1 }; - double problem_alpha[] = { 1 }; - double problem_beta[] = { 0.0 }; - - Testbed testbed; - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int split_k : split_k_slices) { - - if (!Gemm::kSplitKSerial && split_k > 1) { - continue; - } - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - cutlass::gemm::GemmCoord problem_size(m, n, k); - - for (int i = 0; i < iterations; i++){ - try { - passed = testbed.run( - problem_size, - split_k, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - } - catch (std::exception const& e) { - EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " - "exception {alpha: " << alpha << ", beta: " << beta << ", m: " - << m << ", n: " << n << ", k: " << k << "}: " << e.what(); - throw; - } - catch (...) { - EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " - "exception {alpha: " << alpha << ", beta: " << beta << ", m: " - << m << ", n: " << n << ", k: " << k << "}: (unknown)"; - throw; - } - } - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h deleted file mode 100644 index add984ca3b9a0c05325b93cf52cbadd710527ba6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h +++ /dev/null @@ -1,294 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm_complex.h" - -#include "testbed.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedComplex : public Testbed { - - using Base = Testbed; - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - - // - // Methods - // - - TestbedComplex( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - Base(init_A_, init_B_, init_C_, seed_) { } - - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - - cutlass::reference::host::GemmComplex( - problem_size, - alpha, - this->tensor_A.host_ref(), - Gemm::kTransformA, - this->tensor_B.host_ref(), - Gemm::kTransformB, - beta, - this->tensor_C.host_ref(), - this->reference_D.host_ref(), - ElementAccumulator(0) - ); - - return this->compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmCoord problem_size, - int split_k_slices = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - // - // Initialize workspace - // - - this->initialize(problem_size); - - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - problem_size, - this->tensor_A.device_ref(), - this->tensor_B.device_ref(), - this->tensor_C.device_ref(), - this->tensor_D.device_ref(), - {alpha, beta}, - split_k_slices - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllGemmComplex() { - bool passed = true; - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - int const kMinimumOperandElementSize = - std::min( - int(cutlass::sizeof_bits::value), - int(cutlass::sizeof_bits::value)); - - int const kAlignment = - cutlass::platform::is_same< - typename Gemm::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - int problem_size_m[] = { - kAlignment, 512 - 3*kAlignment - }; - - int problem_size_n[] = { - kAlignment, 512 - 2*kAlignment - }; - - int problem_size_k[] = { - kAlignment, 128 - kAlignment - }; - - int split_k_slices[] = { - 1, 2, 3 - }; - - double problem_alpha[] = { - 1 - }; - - double problem_beta[] = { - 2.0 - }; - - TestbedComplex testbed; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int split_k : split_k_slices) { - - if (!Gemm::kSplitKSerial && split_k > 1) { - continue; - } - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - cutlass::gemm::GemmCoord problem_size(m, n, k); - - passed = testbed.run( - problem_size, - split_k, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h deleted file mode 100644 index eca0b0ae0decf3293f6f73cb6ebbc5b5735a8e49..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ /dev/null @@ -1,670 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/gemm_complex.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmWithBroadcastReferenceOp { - - using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; - - using ElementCompute = typename OutputOp::ElementCompute; - using ElementZ = typename OutputOp::ElementZ; - using ElementT = typename OutputOp::ElementT; - - typename OutputOp::BinaryOp binary_op; - typename OutputOp::ElementwiseOp elementwise_op; - - GemmWithBroadcastReferenceOp() { } - - void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { - - ElementCompute t_full = binary_op(gemm, bias); - - if (OutputOp::kStoreT) { - T = ElementT(t_full); - } - - if (OutputOp::kStoreZ) { - ElementCompute z_full = elementwise_op(t_full); - Z = ElementZ(z_full); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Fused testbed -// -// Y = GEMM(AB, C) -// -// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) -// -// Z[i, j] = Elementwise(T[i, j]) -// - -template < - typename Gemm, - typename ReferenceOp = GemmWithBroadcastReferenceOp -> -struct TestbedGemmWithBroadcast { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename OutputOp::ElementCompute; - using ElementVector = typename OutputOp::ElementVector; - using ElementZ = typename OutputOp::ElementZ; - using ElementT = typename OutputOp::ElementT; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; // Input A - cutlass::HostTensor tensor_B; // Input B - cutlass::HostTensor tensor_C; // Input C - cutlass::HostTensor tensor_Broadcast; // Input Broadcast - - cutlass::HostTensor tensor_Z; - cutlass::HostTensor tensor_T; - - cutlass::HostTensor tensor_C_ref; - cutlass::HostTensor tensor_Y_ref; - cutlass::HostTensor tensor_Z_ref; - cutlass::HostTensor tensor_T_ref; - - - // - // Methods - // - - TestbedGemmWithBroadcast( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 1; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the GEMM workspace - // - - tensor_A.resize(problem_size.mk()); - tensor_B.resize(problem_size.kn()); - tensor_C.resize(problem_size.mn()); - tensor_Z.resize(problem_size.mn()); - tensor_T.resize(problem_size.mn()); - tensor_Broadcast.resize({ - problem_size.m(), - 1 - }); - - tensor_C_ref.resize(problem_size.mn()); - tensor_Y_ref.resize(problem_size.mn()); - tensor_Z_ref.resize(problem_size.mn()); - tensor_T_ref.resize(problem_size.mn()); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); - - for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { - for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { - tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); - } - } - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_Broadcast.sync_device(); - - tensor_Z.sync_device(); - tensor_T.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementAccumulator alpha, - ElementAccumulator beta) { - - tensor_Z.sync_host(); - tensor_T.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (OutputOp::kStoreZ) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); - } - - if (OutputOp::kStoreT) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); - } - - bool passed = true; - float norm_diff = 0; - - if (OutputOp::kStoreZ) { - norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); - passed = (norm_diff <= 0.1f); - EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; - } - - if (OutputOp::kStoreT) { - - norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); - passed = (passed && (norm_diff <= 0.1f)); - - EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; - } - - - if (!passed) { - - /* - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - */ - - std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); - - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\nZ =\n" << tensor_Z.host_view() - << "\nT =\n" << tensor_T.host_view() - << "\n\n" - << "\nY_ref =\n" << tensor_Y_ref.host_view() - << "\nZ_ref =\n" << tensor_Z_ref.host_view() - << "\nT_ref =\n" << tensor_T_ref.host_view(); - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementAccumulator alpha, - ElementAccumulator beta) { - - // - // Verify - // - - cutlass::reference::host::GemmComplex< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - ElementAccumulator, typename Gemm::LayoutC, - ElementAccumulator, ElementAccumulator - >( - problem_size, - alpha, - tensor_A.host_ref(), - Gemm::kTransformA, - tensor_B.host_ref(), - Gemm::kTransformB, - beta, - tensor_C_ref.host_ref(), - tensor_Y_ref.host_ref(), - ElementAccumulator(0) - ); - - using ElementC = typename Gemm::ElementC; - - ReferenceOp reference_op; - - // compute tensor Z and tensor T - for (int m = 0; m < problem_size.m(); ++m) { - for (int n = 0; n < problem_size.n(); ++n) { - - ElementZ z; - ElementT t; - - reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); - - if (OutputOp::kStoreZ) { - tensor_Z_ref.at({m, n}) = z; - } - - if (OutputOp::kStoreT) { - tensor_T_ref.at({m, n}) = t; - } - } - } - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementAccumulator alpha = ElementAccumulator(1), - ElementAccumulator beta = ElementAccumulator(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_Z.device_data(), - tensor_Broadcast.device_data(), - tensor_T.device_data(), - problem_size.m() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - problem_size.m(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_Z.layout().stride(0), - 0, // This must be zero - tensor_T.layout().stride(0), - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = true; - - passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; - } - - // - // Profile - // - - #if 0 // profiling disabled for now. - - int const kWorkspaces = 100; - - cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); - - cudaEvent_t events[2]; - for (auto & event : events) { - cudaError_t result = cudaEventCreate(&event); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); - return false; - break; - } - } - - int const kWarmupIterations = 5; - int const kProfilingIterations = 100; - - for (int i = 0; i < kWarmupIterations; ++i) { - status = gemm_op(); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - } - - - cudaError_t result = cudaEventRecord(events[0]); - EXPECT_EQ(result, cudaSuccess); - - for (int i = 0; i < kProfilingIterations; ++i) { - - typename Gemm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), - profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), - profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), - profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), - profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), - profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), - problem_size.m() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - problem_size.m(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_Z.layout().stride(0), - 0, // This must be zero - tensor_T.layout().stride(0), - }; - - gemm_op.initialize(arguments, workspace.get()); - status = gemm_op(); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - } - - result = cudaEventRecord(events[1]); - EXPECT_EQ(result, cudaSuccess); - - result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess); - - float elapsed_time = 0; - result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); - EXPECT_EQ(result, cudaSuccess); - - double average_time = double(elapsed_time) / double(kProfilingIterations); - - std::cout << problem_size << ": " << average_time << " ms" << std::endl; - - for (auto & event : events) { - cudaEventDestroy(event); - } - #endif - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - typename ReferenceOp = GemmWithBroadcastReferenceOp -> -bool TestGemmWithBroadcast( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count, - double alpha = 1.0, - double beta = 2.0) { - - bool passed = true; - - TestbedGemmWithBroadcast testbed; - - using ElementAccumulator = typename Gemm::ElementAccumulator; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - typename ReferenceOp = GemmWithBroadcastReferenceOp -> -bool TestAllGemmWithBroadcast() { - - int M_problems[] = {8, 136, 264, 520}; - int N_problems[] = {8, 136, 264, 520}; - int K_problems[] = {8, 136, 264, 520}; - double alpha_problems[] = {1.25, 2.25}; - double beta_problems[] = {0, 1, 2.0}; - - bool passed = true; - - for (int M : M_problems) { - for (int N : N_problems) { - for (int K : K_problems) { - for (double alpha : alpha_problems) { - for (double beta : beta_problems) { - - TestbedGemmWithBroadcast testbed; - - using ElementAccumulator = typename Gemm::ElementAccumulator; - - passed = testbed.run( - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, - 1, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - EXPECT_TRUE(passed) - << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; - - if (!passed) { - - return passed; - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h deleted file mode 100644 index af3629ccfb87e09e80b85af508379780d6428dc5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h +++ /dev/null @@ -1,588 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/gemm_complex.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmWithReductionReference { - - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; - using ElementC = typename Gemm::ElementC; - using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; - // - // Data members - // - - BinaryOp binary_op; - - // - // Methods - // - - GemmWithReductionReference() { } - - ElementCompute operator()( - ElementAccumulator d_y, - ElementT t) { - - return binary_op(ElementCompute(d_y), ElementCompute(t)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - typename ReferenceOp -> -struct TestbedGemmWithReduction { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor tensor_Reduction; - cutlass::HostTensor tensor_Tensor; - cutlass::HostTensor tensor_C_ref; - cutlass::HostTensor reference_d_Y; - cutlass::HostTensor reference_D; - cutlass::HostTensor reference_Reduction; - - // - // Methods - // - - TestbedGemmWithReduction( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 1; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - for (int m = 0; m < view.extent().row(); ++m) { - for (int n = 0; n < view.extent().column(); ++n) { - //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); - view.at({m, n}) = (n == 0 ? Element(m) : Element()); - - } - } - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the GEMM workspace - // - - tensor_A.resize(problem_size.mk()); - tensor_B.resize(problem_size.kn()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - - tensor_Reduction.resize({ - problem_size.m(), - (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN - }); - - tensor_Tensor.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - reference_d_Y.resize(problem_size.mn(), false); - tensor_C_ref.resize(problem_size.mn(), false); - reference_Reduction.resize({problem_size.m(), 1}, false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); - - for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { - for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { - tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); - } - } - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - tensor_Reduction.sync_device(); - tensor_Tensor.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementAccumulator alpha, - ElementAccumulator beta) { - - tensor_Reduction.sync_host(); - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); - - bool passed = true; - for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { - - ElementAccumulator reduced_value = ElementAccumulator(); - for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { - reduced_value += tensor_Reduction.at({m, j}); - } - - if (reduced_value != reference_Reduction.at({m, 0})) { - std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; - passed = false; - break; - } - } - EXPECT_TRUE(passed) << "Reduction is incorect."; - - if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { - EXPECT_TRUE(false) << " mismatched reference"; - passed = false; - } - - if (!passed) { - - /* - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - */ - - std::ofstream file("testbed_universal_errors_sm70.txt"); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\nT = \n" << tensor_Tensor.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\nComputed =\n" << tensor_D.host_view() - << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" - << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementAccumulator alpha, - ElementAccumulator beta) { - - // - // Verify - // - - cutlass::reference::host::GemmComplex< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - ElementAccumulator, typename Gemm::LayoutC, - ElementAccumulator, ElementAccumulator - >( - problem_size, - alpha, - tensor_A.host_ref(), - Gemm::kTransformA, - tensor_B.host_ref(), - Gemm::kTransformB, - beta, - tensor_C_ref.host_ref(), - reference_d_Y.host_ref(), - ElementAccumulator(0) - ); - - using ElementC = typename Gemm::ElementC; - - ReferenceOp reference_op; - - // compute backwards - for (int m = 0; m < problem_size.m(); ++m) { - ElementAccumulator reduced_value = ElementAccumulator(); - for (int n = 0; n < problem_size.n(); ++n) { - ElementAccumulator d_full = reference_op(reference_d_Y.at({m, n}), tensor_Tensor.at({m, n})); - reduced_value += d_full; - reference_D.at({m, n}) = ElementC(d_full); - } - reference_Reduction.at({m, 0}) = reduced_value; - } - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementAccumulator alpha = ElementAccumulator(1), - ElementAccumulator beta = ElementAccumulator(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_D.device_data(), - tensor_Reduction.device_data(), - tensor_Tensor.device_data(), - problem_size.m() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - problem_size.m(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0), - tensor_Reduction.layout().stride(0), - tensor_Tensor.layout().stride(0), - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; - } - - // - // Profile - // - - #if 0 // profiling disabled for now. - - int const kWorkspaces = 100; - - cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); - cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); - - cudaEvent_t events[2]; - for (auto & event : events) { - cudaError_t result = cudaEventCreate(&event); - if (result != cudaSuccess) { - EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); - return false; - break; - } - } - - int const kWarmupIterations = 5; - int const kProfilingIterations = 100; - - for (int i = 0; i < kWarmupIterations; ++i) { - status = gemm_op(); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - } - - - cudaError_t result = cudaEventRecord(events[0]); - EXPECT_EQ(result, cudaSuccess); - - for (int i = 0; i < kProfilingIterations; ++i) { - - typename Gemm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), - profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), - profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), - profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), - profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), - profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), - problem_size.m() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - problem_size.m(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0), - tensor_Reduction.layout().stride(0), - tensor_Tensor.layout().stride(0), - }; - - gemm_op.initialize(arguments, workspace.get()); - status = gemm_op(); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - } - - result = cudaEventRecord(events[1]); - EXPECT_EQ(result, cudaSuccess); - - result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess); - - float elapsed_time = 0; - result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); - EXPECT_EQ(result, cudaSuccess); - - double average_time = double(elapsed_time) / double(kProfilingIterations); - - std::cout << problem_size << ": " << average_time << " ms" << std::endl; - - for (auto & event : events) { - cudaEventDestroy(event); - } - #endif - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestGemmWithReduction( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count = 1, - double alpha = 1.0, - double beta = 2.0) { - - bool passed = true; - - TestbedGemmWithReduction testbed; - - using ElementAccumulator = typename Gemm::ElementAccumulator; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h deleted file mode 100644 index c7317eb855477e63fe19858ca51cd5722f236eb5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h +++ /dev/null @@ -1,501 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface - -*/ - -#pragma once - -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" -#include "cutlass/gemm/device/gemm_grouped.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/gemm_complex.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/tensor_view_io.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedGrouped { - - // - // Type definitions - // - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - - using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using LayoutA = typename Gemm::LayoutA; - using LayoutB = typename Gemm::LayoutB; - using LayoutC = typename Gemm::LayoutC; - - using MatrixCoord = typename LayoutC::TensorCoord; - - // - // Data members - // - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint32_t seed; - - int problem_count; - - std::vector problem_sizes_host; - cutlass::DeviceAllocation problem_sizes_device; - - std::vector offset_A; - std::vector offset_B; - std::vector offset_C; - std::vector offset_D; - - std::vector lda_host; - std::vector ldb_host; - std::vector ldc_host; - std::vector ldd_host; - - cutlass::DeviceAllocation lda; - cutlass::DeviceAllocation ldb; - cutlass::DeviceAllocation ldc; - cutlass::DeviceAllocation ldd; - - cutlass::DeviceAllocation block_A; - cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_D; - - cutlass::DeviceAllocation ptr_A; - cutlass::DeviceAllocation ptr_B; - cutlass::DeviceAllocation ptr_C; - cutlass::DeviceAllocation ptr_D; - - // - // Methods - // - - TestbedGrouped( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint32_t seed_ = 3080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint32_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - if (cutlass::sizeof_bits::value <= 16) { - scope_max = 5; - scope_min = -5; - } - else { - scope_max = 8; - scope_min = -8; - } - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - // no fill - remain zero - } - - return true; - } - - /// Initializes data structures - void initialize() { - - // - // Choose random problem sizes - // - - // construct a few problems of random sizes - srand(seed); - - int64_t total_elements_A = 0; - int64_t total_elements_B = 0; - int64_t total_elements_C = 0; - int64_t total_elements_D = 0; - - - lda_host.resize(problem_count); - ldb_host.resize(problem_count); - ldc_host.resize(problem_count); - ldd_host.resize(problem_count); - - problem_sizes_host.clear(); - problem_sizes_host.resize(problem_count); - - for (int32_t i = 0; i < problem_count; ++i) { - - cutlass::gemm::GemmCoord problem( - 8 * (rand() % 64) + 24, - 8 * (rand() % 64) + 24, - 8 * (rand() % 64) + 24); - - if (!i) { - problem = cutlass::gemm::GemmCoord(48, 16, 8); - } - - problem_sizes_host.at(i) = problem; - - // std::cout << "Problem[" << i << "]: " << problem << std::endl; - - lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); - ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); - ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); - ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); - - offset_A.push_back(total_elements_A); - offset_B.push_back(total_elements_B); - offset_C.push_back(total_elements_C); - offset_D.push_back(total_elements_D); - - int64_t elements_A = problem.m() * problem.k(); - int64_t elements_B = problem.k() * problem.n(); - int64_t elements_C = problem.m() * problem.n(); - int64_t elements_D = problem.m() * problem.n(); - - total_elements_A += elements_A; - total_elements_B += elements_B; - total_elements_C += elements_C; - total_elements_D += elements_D; - - // Random strides between problems? - } - - problem_sizes_device.reset(problem_count); - problem_sizes_device.copy_from_host(problem_sizes_host.data()); - - lda.reset(problem_count); - ldb.reset(problem_count); - ldc.reset(problem_count); - ldd.reset(problem_count); - - lda.copy_from_host(lda_host.data()); - ldb.copy_from_host(ldb_host.data()); - ldc.copy_from_host(ldc_host.data()); - ldd.copy_from_host(ldd_host.data()); - - // - // Assign pointers - // - - block_A.reset(total_elements_A); - block_B.reset(total_elements_B); - block_C.reset(total_elements_C); - block_D.reset(total_elements_D); - - std::vector ptr_A_host(problem_count); - std::vector ptr_B_host(problem_count); - std::vector ptr_C_host(problem_count); - std::vector ptr_D_host(problem_count); - - for (int32_t i = 0; i < problem_count; ++i) { - ptr_A_host.at(i) = block_A.get() + offset_A.at(i); - ptr_B_host.at(i) = block_B.get() + offset_B.at(i); - ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - ptr_D_host.at(i) = block_D.get() + offset_D.at(i); - } - - ptr_A.reset(problem_count); - ptr_A.copy_from_host(ptr_A_host.data()); - - ptr_B.reset(problem_count); - ptr_B.copy_from_host(ptr_B_host.data()); - - ptr_C.reset(problem_count); - ptr_C.copy_from_host(ptr_C_host.data()); - - ptr_D.reset(problem_count); - ptr_D.copy_from_host(ptr_D_host.data()); - - // - // Initialize the problems of the workspace - // - - for (int32_t i = 0; i < problem_count; ++i) { - cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); - - LayoutA layout_A(lda_host.at(i)); - LayoutB layout_B(ldb_host.at(i)); - LayoutC layout_C(ldc_host.at(i)); - LayoutC layout_D(ldd_host.at(i)); - - MatrixCoord extent_A{problem.m(), problem.k()}; - MatrixCoord extent_B{problem.k(), problem.n()}; - MatrixCoord extent_C{problem.m(), problem.n()}; - - std::vector matrix_A(layout_A.capacity(extent_A)); - std::vector matrix_B(layout_B.capacity(extent_B)); - std::vector matrix_C(layout_C.capacity(extent_C)); - std::vector matrix_D(layout_D.capacity(extent_C)); - - initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); - initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); - initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); - - cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); - cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); - cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); - cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); - } - } - - /// Verifies the result is a GEMM - bool verify( - ElementCompute alpha, - ElementCompute beta) { - - bool passed = true; - - for (int32_t i = 0; i < problem_count; ++i) { - cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); - - LayoutA layout_A(lda_host.at(i)); - LayoutB layout_B(ldb_host.at(i)); - LayoutC layout_C(ldc_host.at(i)); - LayoutC layout_D(ldd_host.at(i)); - - MatrixCoord extent_A{problem.m(), problem.k()}; - MatrixCoord extent_B{problem.k(), problem.n()}; - MatrixCoord extent_C{problem.m(), problem.n()}; - - std::vector matrix_A(layout_A.capacity(extent_A)); - std::vector matrix_B(layout_B.capacity(extent_B)); - std::vector matrix_C(layout_C.capacity(extent_C)); - std::vector matrix_D(layout_D.capacity(extent_C)); - std::vector matrix_Ref(layout_D.capacity(extent_C)); - - cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); - cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); - cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); - cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); - - cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); - cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); - cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); - cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); - cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); - - // Reference GEMM - cutlass::reference::host::GemmComplex< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, ElementAccumulator - >( - problem, - alpha, - view_A, - Gemm::kTransformA, - view_B, - Gemm::kTransformB, - beta, - view_C, - view_Ref, - ElementAccumulator(0) - ); - - // Ensure that no input or output is entirely zero - EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); - - // Compare against reference - passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); - - if (!passed) { - std::ofstream file("testbed_grouped_errors.txt"); - - file - << "problem: " << problem << " [group: " << i << "]\n" - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << view_A - << "\nB =\n" << view_B - << "\nC =\n" << view_C - << "\n\nReference =\n" << view_Ref - << "\nComputed =\n" << view_D; - - return passed; - } - } - - return passed; - } - - /// Executes one test - bool run( - int problem_count, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - this->problem_count = problem_count; - - // Initialize the problem - initialize(); - - int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); - - // Early exit - if (!threadblock_count) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; - } - return true; - } - - // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(alpha, beta); - - // Configure GEMM arguments - typename Gemm::Arguments args( - problem_sizes_device.get(), - problem_count, - threadblock_count, - epilogue_op, - ptr_A.get(), - ptr_B.get(), - ptr_C.get(), - ptr_D.get(), - lda.get(), - ldb.get(), - ldc.get(), - ldd.get(), - problem_sizes_host.data() - ); - - // Initialize the GEMM object - Gemm gemm; - - size_t workspace_size = gemm.get_workspace_size(args); - cutlass::DeviceAllocation workspace(workspace_size); - - cutlass::Status status = gemm.initialize(args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - return false; - } - - // Run the GEMM object - status = gemm.run(); - - if (status != cutlass::Status::kSuccess) { - return false; - } - - // Wait for completion - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) - << "Kernel execution error: " << cudaGetErrorString(result); - - if (result != cudaSuccess) { - return false; - } - - // Verify correctness - return verify(alpha, beta); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // device -} // gemm -} // test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h deleted file mode 100644 index f8f08f23c4477745648f1cf8f9e439ae6b5061e2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h +++ /dev/null @@ -1,502 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for grouped Rank2K interface - -*/ - -#pragma once - -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/rank_2k_grouped.h" -#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -#include "cutlass/gemm/device/rank_2k_grouped.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/rank_2k_complex.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/tensor_view_io.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedGrouped { - - // - // Type definitions - // - - using ElementA = typename Rank2K::ElementA; - using ElementB = typename Rank2K::ElementB; - using ElementC = typename Rank2K::ElementC; - using ElementAccumulator = typename Rank2K::ElementAccumulator; - - using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; - - using LayoutA = typename Rank2K::LayoutA; - using LayoutB = typename Rank2K::LayoutB; - using LayoutC = typename Rank2K::LayoutC; - - using MatrixCoord = typename LayoutC::TensorCoord; - - // - // Data members - // - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint32_t seed; - - int problem_count; - - std::vector problem_sizes_host; - cutlass::DeviceAllocation problem_sizes_device; - - std::vector offset_A; - std::vector offset_B; - std::vector offset_C; - std::vector offset_D; - - std::vector lda_host; - std::vector ldb_host; - std::vector ldc_host; - std::vector ldd_host; - - cutlass::DeviceAllocation lda; - cutlass::DeviceAllocation ldb; - cutlass::DeviceAllocation ldc; - cutlass::DeviceAllocation ldd; - - cutlass::DeviceAllocation block_A; - cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_D; - - cutlass::DeviceAllocation ptr_A; - cutlass::DeviceAllocation ptr_B; - cutlass::DeviceAllocation ptr_C; - cutlass::DeviceAllocation ptr_D; - - // - // Methods - // - - TestbedGrouped( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint32_t seed_ = 3080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint32_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - if (cutlass::sizeof_bits::value <= 16) { - scope_max = 5; - scope_min = -5; - } - else { - scope_max = 8; - scope_min = -8; - } - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - // no fill - remain zero - } - - return true; - } - - /// Initializes data structures - void initialize() { - - // - // Choose random problem sizes - // - - // construct a few problems of random sizes - srand(seed); - - int64_t total_elements_A = 0; - int64_t total_elements_B = 0; - int64_t total_elements_C = 0; - int64_t total_elements_D = 0; - - - lda_host.resize(problem_count); - ldb_host.resize(problem_count); - ldc_host.resize(problem_count); - ldd_host.resize(problem_count); - - problem_sizes_host.clear(); - problem_sizes_host.resize(problem_count); - - for (int32_t i = 0; i < problem_count; ++i) { - - auto N = 8 * (rand() % 64) + 24; - auto K = 8 * (rand() % 64) + 24; - cutlass::gemm::GemmCoord problem(N, N, K); - - if (!i) { - problem = cutlass::gemm::GemmCoord(16, 16, 8); - } - - problem_sizes_host.at(i) = problem; - - lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); - ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); - ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); - ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); - - offset_A.push_back(total_elements_A); - offset_B.push_back(total_elements_B); - offset_C.push_back(total_elements_C); - offset_D.push_back(total_elements_D); - - int64_t elements_A = problem.n() * problem.k(); - int64_t elements_B = problem.n() * problem.k(); - int64_t elements_C = problem.n() * problem.n(); - int64_t elements_D = problem.n() * problem.n(); - - total_elements_A += elements_A; - total_elements_B += elements_B; - total_elements_C += elements_C; - total_elements_D += elements_D; - - // Random strides between problems? - } - - problem_sizes_device.reset(problem_count); - problem_sizes_device.copy_from_host(problem_sizes_host.data()); - - lda.reset(problem_count); - ldb.reset(problem_count); - ldc.reset(problem_count); - ldd.reset(problem_count); - - lda.copy_from_host(lda_host.data()); - ldb.copy_from_host(ldb_host.data()); - ldc.copy_from_host(ldc_host.data()); - ldd.copy_from_host(ldd_host.data()); - - // - // Assign pointers - // - - block_A.reset(total_elements_A); - block_B.reset(total_elements_B); - block_C.reset(total_elements_C); - block_D.reset(total_elements_D); - - std::vector ptr_A_host(problem_count); - std::vector ptr_B_host(problem_count); - std::vector ptr_C_host(problem_count); - std::vector ptr_D_host(problem_count); - - for (int32_t i = 0; i < problem_count; ++i) { - ptr_A_host.at(i) = block_A.get() + offset_A.at(i); - ptr_B_host.at(i) = block_B.get() + offset_B.at(i); - ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - ptr_D_host.at(i) = block_D.get() + offset_D.at(i); - } - - ptr_A.reset(problem_count); - ptr_A.copy_from_host(ptr_A_host.data()); - - ptr_B.reset(problem_count); - ptr_B.copy_from_host(ptr_B_host.data()); - - ptr_C.reset(problem_count); - ptr_C.copy_from_host(ptr_C_host.data()); - - ptr_D.reset(problem_count); - ptr_D.copy_from_host(ptr_D_host.data()); - - // - // Initialize the problems of the workspace - // - - for (int32_t i = 0; i < problem_count; ++i) { - cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); - - LayoutA layout_A(lda_host.at(i)); - LayoutB layout_B(ldb_host.at(i)); - LayoutC layout_C(ldc_host.at(i)); - LayoutC layout_D(ldd_host.at(i)); - - MatrixCoord extent_A{problem.n(), problem.k()}; - MatrixCoord extent_B{problem.n(), problem.k()}; - MatrixCoord extent_C{problem.n(), problem.n()}; - - std::vector matrix_A(layout_A.capacity(extent_A)); - std::vector matrix_B(layout_B.capacity(extent_B)); - std::vector matrix_C(layout_C.capacity(extent_C)); - std::vector matrix_D(layout_D.capacity(extent_C)); - - initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); - initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); - initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); - - cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); - cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); - cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); - cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); - } - } - - /// Verifies the result is a Rank2K - bool verify( - ElementCompute alpha, - ElementCompute beta) { - - bool passed = true; - - for (int32_t i = 0; i < problem_count; ++i) { - cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); - - LayoutA layout_A(lda_host.at(i)); - LayoutB layout_B(ldb_host.at(i)); - LayoutC layout_C(ldc_host.at(i)); - LayoutC layout_D(ldd_host.at(i)); - - MatrixCoord extent_A{problem.n(), problem.k()}; - MatrixCoord extent_B{problem.n(), problem.k()}; - MatrixCoord extent_C{problem.n(), problem.n()}; - - std::vector matrix_A(layout_A.capacity(extent_A)); - std::vector matrix_B(layout_B.capacity(extent_B)); - std::vector matrix_C(layout_C.capacity(extent_C)); - std::vector matrix_D(layout_D.capacity(extent_C)); - std::vector matrix_Ref(layout_D.capacity(extent_C)); - - cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); - cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); - cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); - cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); - - cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); - cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); - cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); - cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); - cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); - - // Reference Rank2K - cutlass::reference::host::Rank2KComplex< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, ElementAccumulator - >( - problem, - alpha, - view_A, - Rank2K::kTransformA, - view_B, - Rank2K::kTransformB, - beta, - view_C, - view_Ref, - ElementAccumulator(0), - Rank2K::kFillModeC, - Rank2K::kBlasMode - ); - - // Ensure that no input or output is entirely zero - EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); - - // Compare against reference - passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); - - if (!passed) { - std::ofstream file("testbed_grouped_errors.txt"); - - file - << "problem: " << problem << " [group: " << i << "]\n" - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << view_A - << "\nB =\n" << view_B - << "\nC =\n" << view_C - << "\n\nReference =\n" << view_Ref - << "\nComputed =\n" << view_D; - - return passed; - } - } - - return passed; - } - - /// Executes one test - bool run( - int problem_count, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - this->problem_count = problem_count; - - // Initialize the problem - initialize(); - - int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); - - // Early exit - if (!threadblock_count) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; - } - return true; - } - - // Configure the Rank2K arguments - typename EpilogueOutputOp::Params epilogue_op(alpha, beta); - - // Configure Rank2K arguments - typename Rank2K::Arguments args( - cutlass::gemm::GemmUniversalMode::kGemm, - problem_sizes_device.get(), - problem_count, - threadblock_count, - epilogue_op, - ptr_A.get(), - ptr_B.get(), - ptr_C.get(), - ptr_D.get(), - lda.get(), - ldb.get(), - ldc.get(), - ldd.get(), - problem_sizes_host.data() - ); - - // Initialize the Rank2K object - Rank2K rank2k; - - size_t workspace_size = rank2k.get_workspace_size(args); - cutlass::DeviceAllocation workspace(workspace_size); - - cutlass::Status status = rank2k.initialize(args, workspace.get()); - - if (status != cutlass::Status::kSuccess) { - return false; - } - - // Run the Rank2K object - status = rank2k.run(); - - if (status != cutlass::Status::kSuccess) { - return false; - } - - // Wait for completion - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) - << "Kernel execution error: " << cudaGetErrorString(result); - - if (result != cudaSuccess) { - return false; - } - - // Verify correctness - return verify(alpha, beta); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // device -} // gemm -} // test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h deleted file mode 100644 index e9315e12e8711f50256e4cfe05666201acd614d3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h +++ /dev/null @@ -1,461 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for grouped Rank2K problem visitors -*/ - -#pragma once - -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/device_kernel.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Use simple problem visitor as a baseline -template -struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { - using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; - using Params = typename Base::Params; - static int const kThreadCount = ThreadCount; - static cutlass::FillMode const kFillModeC = FillModeC; - - struct SharedStorage {}; - - int32_t tile_count_sum; - SharedStorage &shared_storage; - - // - // Methods - // - CUTLASS_DEVICE - BaselineProblemVisitor( - Params const ¶ms_, - SharedStorage &shared_storage_, - int32_t block_idx - ): Base(params_, block_idx), - shared_storage(shared_storage_) - { - cutlass::gemm::GemmCoord problem = this->problem_size(); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - tile_count_sum = this->tile_count(grid); - } - - CUTLASS_DEVICE - bool next_tile() { - if (this->tile_idx < tile_count_sum) { - return true; - } - - do { - ++this->problem_idx; - - if (this->problem_idx >= this->params.problem_count) { - return false; - } - - cutlass::gemm::GemmCoord problem = this->problem_size(); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - - this->problem_tile_start = tile_count_sum; - tile_count_sum += this->tile_count(grid); - - } while (tile_count_sum <= this->tile_idx); - - return true; - } - - static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, - int32_t problem_count, - int32_t block_count) { - return 0; - } - - static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, - int32_t problem_count, - int32_t block_count, - void* host_workspace_ptr) {} - - CUTLASS_DEVICE - cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { - int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; - int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; - int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); - - if (FillModeC == cutlass::FillMode::kUpper) { - cutlass::swap(macro_row, macro_col); - } - - int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); - int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); - - return cutlass::gemm::GemmCoord(row, col, 0); - } -}; - -template -struct ProblemVisitorKernel { - struct SharedStorage { - typename ProblemVisitor::SharedStorage problem_visitor; - }; - - struct Params { - typename ProblemVisitor::Params problem_visitor_params; - int32_t* visited_problems_ptr; - int32_t* visited_tiles_ptr; - int32_t visits_per_block; - - Params(): - visited_problems_ptr(nullptr), - visited_tiles_ptr(nullptr), - visits_per_block(0) {} - - Params(typename ProblemVisitor::Params problem_visitor_params_, - int32_t* visited_problems_ptr_, - int32_t* visited_tiles_ptr_, - int32_t visits_per_block_): - problem_visitor_params(problem_visitor_params_), - visited_problems_ptr(visited_problems_ptr_), - visited_tiles_ptr(visited_tiles_ptr_), - visits_per_block(visits_per_block_) {} - }; - - CUTLASS_DEVICE - void operator()(const Params& params, SharedStorage &shared_storage) { - int32_t store_offset = params.visits_per_block * blockIdx.x; - ProblemVisitor problem_visitor(params.problem_visitor_params, - shared_storage.problem_visitor, - blockIdx.x); - - while (problem_visitor.next_tile()) { - cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - - cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); - - problem_visitor.advance(gridDim.x); - - // - // Early exit conditions - // 1) Out of range - // 2) Upper-triangular block in lower-triangular problem - // 3) Lower-triangular block in upper-triangular problem - // - - if (grid_shape.m() <= tile_offset.m() || - grid_shape.n() <= tile_offset.n()) { - continue; - } - - if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && - (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { - continue; - } - - if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && - tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { - continue; - } - - if (threadIdx.x == 0) { - params.visited_problems_ptr[store_offset] = problem_idx; - params.visited_tiles_ptr[store_offset] = threadblock_idx; - ++store_offset; - } - } - } -}; - -template -struct ProblemVisitorRunner { - using BaseKernel = ProblemVisitorKernel; - using Params = typename BaseKernel::Params; - - Params params; - std::vector host_problem_sizes; - int32_t problem_count; - int32_t threadblock_count; - int32_t visits_per_block; - cutlass::DeviceAllocation visited_problems; - cutlass::DeviceAllocation visited_tiles; - cutlass::DeviceAllocation device_problem_sizes; - cutlass::DeviceAllocation workspace; - std::vector host_visited_problems; - std::vector host_visited_tiles; - - ProblemVisitorRunner(const std::vector& host_problem_sizes_, - int32_t threadblock_count_): - host_problem_sizes(host_problem_sizes_), - problem_count(int32_t(host_problem_sizes_.size())), - threadblock_count(threadblock_count_) {} - - /// Initializes GEMM state from arguments. - cutlass::Status initialize() { - size_t workspace_bytes = ProblemVisitor::get_workspace_size( - host_problem_sizes.data(), - problem_count, - threadblock_count); - - workspace.reset(workspace_bytes); - std::vector host_workspace(workspace_bytes); - - int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); - - ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, - threadblock_count, host_workspace.data()); - - workspace.copy_from_host(host_workspace.data(), workspace_bytes); - - device_problem_sizes.reset(problem_count); - device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); - - visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; - int32_t total_visits = visits_per_block * threadblock_count; - - visited_problems.reset(total_visits); - visited_tiles.reset(total_visits); - host_visited_problems.resize(total_visits); - host_visited_tiles.resize(total_visits); - - cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); - if (result != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } - - result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); - if (result != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } - - typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); - params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); - - return cutlass::Status::kSuccess; - } - - bool verify() { - // Sort by problem size and then by threadblock_idx - std::vector indices(host_visited_problems.size()); - std::iota(indices.begin(), indices.end(), 0); - - std::stable_sort(indices.begin(), indices.end(), - [&](int32_t i1, int32_t i2) { - if (host_visited_problems[i1] == host_visited_problems[i2]) { - return host_visited_tiles[i1] < host_visited_tiles[i2]; - } - return host_visited_problems[i1] < host_visited_problems[i2]; - }); - - int32_t idx = 0; - - // Skip any entries that were not visited - while (host_visited_problems[indices[idx]] == -1) { - ++idx; - } - - // Check that each problem visited has the tiles we expect - for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { - auto problem = host_problem_sizes[problem_idx]; - ProblemVisitor::possibly_transpose_problem(problem); - int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); - for (int i = 0; i < problem_tiles; ++i) { - EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); - EXPECT_EQ(i, host_visited_tiles[indices[idx]]); - ++idx; - } - } - - return true; - } - - bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { - cutlass::Status status = initialize(); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Initialization failed" << std::endl; - return false; - } - - dim3 grid(threadblock_count, 1, 1); - dim3 block(ProblemVisitor::kThreadCount, 1, 1); - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - cutlass::Kernel<<>>(params); - - cudaError_t result = cudaGetLastError(); - if (result != cudaSuccess) { - std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; - return false; - } - - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; - return false; - } - - visited_problems.copy_to_host(host_visited_problems.data()); - visited_tiles.copy_to_host(host_visited_tiles.data()); - - if (skip_tile_check) { - return true; - } - - return verify(); - } -}; - -template -struct TestbedGroupedRank2KScheduler { - - using BaselinePV = BaselineProblemVisitor, - ThreadblockShape, - PrefetchTileCount, - ThreadCount, - FillModeC>; - - // - // Data members - // - - // Whether to skip checking that the tiles are visited as expected. This is useful - // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped - // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to - // exit early, but which are difficult to detect in tests without reimplementing - // this functionality. - bool skip_tile_check; - uint32_t seed; - int problem_count; - int threadblock_count; - std::vector problem_sizes_host; - - // - // Methods - // - - TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): - skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } - - /// Initializes data structures - void initialize(int32_t scale_factor) { - - // - // Choose random problem sizes - // - - problem_sizes_host.clear(); - problem_sizes_host.resize(problem_count); - - for (int32_t i = 0; i < problem_count; ++i) { - int n = scale_factor * (rand() % 64) + 24; - - cutlass::gemm::GemmCoord problem( - n, - n, - scale_factor * (rand() % 64) + 24); - - problem_sizes_host.at(i) = problem; - } - } - - template - void compare_visitors(const ProblemVisitorRunner& baseline_runner) { - using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< - ThreadblockShape, - GroupScheduleMode_, - PrefetchTileCount, - ThreadCount, - FillModeC>; - ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); - EXPECT_TRUE(runner.run(skip_tile_check)); - - // Check that this problem visitor visits the same problems and tiles as the baseline - EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); - EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); - } - - template - void compare_visitors(const ProblemVisitorRunner& baseline_runner) { - // Compare the next visitor with the baseline visitor - compare_visitors(baseline_runner); - - // Recurse to compare the next visitors - compare_visitors(baseline_runner); - } - - /// Executes the test on all scheduler modes - void run(int problem_count, int threadblock_count, int scale_factor=8) { - - this->problem_count = problem_count; - this->threadblock_count = threadblock_count; - - // Initialize the problem - initialize(scale_factor); - - // Run the baseline visitor to which we will compare all other visitors - ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); - EXPECT_TRUE(baseline_runner.run(skip_tile_check)); - - compare_visitors(baseline_runner); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // device -} // gemm -} // test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h deleted file mode 100644 index bda2704b517ea95052e2c2060b50712b686344f6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h +++ /dev/null @@ -1,407 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for grouped GEMM problem visitors -*/ - -#pragma once - -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -#include "cutlass/util/device_memory.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Use simple problem visitor as a baseline -template -struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { - using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; - using Params = typename Base::Params; - static int const kThreadCount = ThreadCount; - - struct SharedStorage {}; - - int32_t tile_count_sum; - SharedStorage &shared_storage; - - // - // Methods - // - CUTLASS_DEVICE - BaselineProblemVisitor( - Params const ¶ms_, - SharedStorage &shared_storage_, - int32_t block_idx - ): Base(params_, block_idx), - shared_storage(shared_storage_) - { - cutlass::gemm::GemmCoord problem = this->problem_size(); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - tile_count_sum = this->tile_count(grid); - } - - CUTLASS_DEVICE - bool next_tile() { - if (this->tile_idx < tile_count_sum) { - return true; - } - - do { - ++this->problem_idx; - - if (this->problem_idx >= this->params.problem_count) { - return false; - } - - cutlass::gemm::GemmCoord problem = this->problem_size(); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - - this->problem_tile_start = tile_count_sum; - tile_count_sum += this->tile_count(grid); - - } while (tile_count_sum <= this->tile_idx); - - return true; - } - - static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, - int32_t problem_count, - int32_t block_count) { - return 0; - } - - static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, - int32_t problem_count, - int32_t block_count, - void* host_workspace_ptr) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct ProblemVisitorKernel { - struct SharedStorage { - typename ProblemVisitor::SharedStorage problem_visitor; - }; - - struct Params { - typename ProblemVisitor::Params problem_visitor_params; - int32_t* visited_problems_ptr; - int32_t* visited_tiles_ptr; - int32_t visits_per_block; - - Params(): - visited_problems_ptr(nullptr), - visited_tiles_ptr(nullptr), - visits_per_block(0) {} - - Params(typename ProblemVisitor::Params problem_visitor_params_, - int32_t* visited_problems_ptr_, - int32_t* visited_tiles_ptr_, - int32_t visits_per_block_): - problem_visitor_params(problem_visitor_params_), - visited_problems_ptr(visited_problems_ptr_), - visited_tiles_ptr(visited_tiles_ptr_), - visits_per_block(visits_per_block_) {} - }; - - CUTLASS_DEVICE - void operator()(const Params& params, SharedStorage &shared_storage) { - int32_t store_offset = params.visits_per_block * blockIdx.x; - ProblemVisitor problem_visitor(params.problem_visitor_params, - shared_storage.problem_visitor, - blockIdx.x); - - while (problem_visitor.next_tile()) { - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - - if (threadIdx.x == 0) { - params.visited_problems_ptr[store_offset] = problem_idx; - params.visited_tiles_ptr[store_offset] = threadblock_idx; - ++store_offset; - } - problem_visitor.advance(gridDim.x); - } - } -}; - -template -struct ProblemVisitorRunner { - using BaseKernel = ProblemVisitorKernel; - using Params = typename BaseKernel::Params; - - Params params; - std::vector host_problem_sizes; - int32_t problem_count; - int32_t threadblock_count; - int32_t visits_per_block; - cutlass::DeviceAllocation visited_problems; - cutlass::DeviceAllocation visited_tiles; - cutlass::DeviceAllocation device_problem_sizes; - cutlass::DeviceAllocation workspace; - std::vector host_visited_problems; - std::vector host_visited_tiles; - - ProblemVisitorRunner(const std::vector& host_problem_sizes_, - int32_t threadblock_count_): - host_problem_sizes(host_problem_sizes_), - problem_count(int32_t(host_problem_sizes_.size())), - threadblock_count(threadblock_count_) {} - - /// Initializes GEMM state from arguments. - cutlass::Status initialize() { - size_t workspace_bytes = ProblemVisitor::get_workspace_size( - host_problem_sizes.data(), - problem_count, - threadblock_count); - - workspace.reset(workspace_bytes); - std::vector host_workspace(workspace_bytes); - - int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); - - ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, - threadblock_count, host_workspace.data()); - - workspace.copy_from_host(host_workspace.data(), workspace_bytes); - - device_problem_sizes.reset(problem_count); - device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); - - visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; - int32_t total_visits = visits_per_block * threadblock_count; - - visited_problems.reset(total_visits); - visited_tiles.reset(total_visits); - host_visited_problems.resize(total_visits); - host_visited_tiles.resize(total_visits); - - cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); - if (result != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } - - result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); - if (result != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } - - typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); - params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); - - return cutlass::Status::kSuccess; - } - - bool verify() { - // Sort by problem size and then by threadblock_idx - std::vector indices(host_visited_problems.size()); - std::iota(indices.begin(), indices.end(), 0); - - std::stable_sort(indices.begin(), indices.end(), - [&](int32_t i1, int32_t i2) { - if (host_visited_problems[i1] == host_visited_problems[i2]) { - return host_visited_tiles[i1] < host_visited_tiles[i2]; - } - return host_visited_problems[i1] < host_visited_problems[i2]; - }); - - int32_t idx = 0; - - // Skip any entries that were not visited - while (host_visited_problems[indices[idx]] == -1) { - ++idx; - } - - // Check that each problem visited has the tiles we expect - for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { - auto problem = host_problem_sizes[problem_idx]; - ProblemVisitor::possibly_transpose_problem(problem); - int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); - for (int i = 0; i < problem_tiles; ++i) { - EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); - EXPECT_EQ(i, host_visited_tiles[indices[idx]]); - ++idx; - } - } - - return true; - } - - bool run(cudaStream_t stream = nullptr) { - cutlass::Status status = initialize(); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Initialization failed" << std::endl; - return false; - } - - dim3 grid(threadblock_count, 1, 1); - dim3 block(ProblemVisitor::kThreadCount, 1, 1); - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - cutlass::Kernel<<>>(params); - - cudaError_t result = cudaGetLastError(); - if (result != cudaSuccess) { - std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; - return false; - } - - result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; - return false; - } - - visited_problems.copy_to_host(host_visited_problems.data()); - visited_tiles.copy_to_host(host_visited_tiles.data()); - - return verify(); - } -}; - -template -struct TestbedGroupedGemmScheduler { - - using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; - using BaselinePV = BaselineProblemVisitor; - - // - // Data members - // - uint32_t seed; - int problem_count; - int threadblock_count; - std::vector problem_sizes_host; - - // - // Methods - // - - TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): - seed(seed_) { srand(seed); } - - /// Initializes data structures - void initialize(int32_t scale_factor) { - - // - // Choose random problem sizes - // - - problem_sizes_host.clear(); - problem_sizes_host.resize(problem_count); - - for (int32_t i = 0; i < problem_count; ++i) { - - cutlass::gemm::GemmCoord problem( - scale_factor * (rand() % 64) + 24, - scale_factor * (rand() % 64) + 24, - scale_factor * (rand() % 64) + 24); - - problem_sizes_host.at(i) = problem; - } - } - - template - void compare_visitors(const ProblemVisitorRunner& baseline_runner) { - using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< - ThreadblockShape, - GroupScheduleMode_, - PrefetchTileCount, - ThreadCount, - Transpose>; - ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); - EXPECT_TRUE(runner.run()); - - // Check that this problem visitor visits the same problems and tiles as the baseline - EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); - EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); - } - - template - void compare_visitors(const ProblemVisitorRunner& baseline_runner) { - // Compare the next visitor with the baseline visitor - compare_visitors(baseline_runner); - - // Recurse to compare the next visitors - compare_visitors(baseline_runner); - } - - /// Executes the test on all scheduler modes - void run(int problem_count, int threadblock_count, int scale_factor=8) { - - this->problem_count = problem_count; - this->threadblock_count = threadblock_count; - - // Initialize the problem - initialize(scale_factor); - - // Run the baseline visitor to which we will compare all other visitors - ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); - EXPECT_TRUE(baseline_runner.run()); - - compare_visitors(baseline_runner); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // device -} // gemm -} // test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h deleted file mode 100644 index 2a5956000db8e8c05ea22538e58149998b03e3fc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h +++ /dev/null @@ -1,346 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/host_reorder.h" - -namespace test { -namespace gemm { -namespace device { - -//////////////////////////////////////////////////////////////////////////////// - -template -struct InterleavedTestbed { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - // - // Methods - // - - InterleavedTestbed( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, 2, -2, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Waives test if CUDA device is insufficient - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - // - // Allocate the GEMM workspace - // - - cutlass::HostTensor< - typename Gemm::ElementA, - typename Gemm::LayoutA> tensor_A(problem_size.mk()); - - cutlass::HostTensor< - typename Gemm::ElementB, - typename Gemm::LayoutB> tensor_B(problem_size.kn()); - - cutlass::HostTensor< - typename Gemm::ElementB, - typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); - - cutlass::HostTensor< - typename Gemm::ElementC, - typename Gemm::LayoutC> tensor_C(problem_size.mn()); - - cutlass::HostTensor< - typename Gemm::ElementC, - typename Gemm::LayoutC> tensor_D(problem_size.mn()); - - cutlass::HostTensor< - typename Gemm::ElementC, - typename Gemm::LayoutC> reference_D(problem_size.mn(), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - - cutlass::reorder_column( - tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); - - cutlass::reference::host::TensorCopy( - reference_D.host_view(), - tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B_reordered.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - problem_size, - tensor_A.device_ref(), - tensor_B_reordered.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), - {alpha, beta} - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.initialize(arguments); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Verify - // - - cutlass::reference::host::Gemm< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, - ElementAccumulator, typename Gemm::Operator> - reference_gemm; - - reference_gemm( - problem_size, - alpha, - tensor_A.host_ref(), - tensor_B.host_ref(), - beta, - reference_D.host_ref(), - ElementAccumulator(0) - ); - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - bool passed = cutlass::reference::host::TensorEquals( - reference_D.host_view(), - tensor_D.host_view()); - - EXPECT_TRUE(passed); - if (!passed) { - - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nB_reordered =\n" << tensor_B_reordered.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\nComputed =\n" << tensor_D.host_view(); - } - - return passed; - } - - /// Runs a set of problem sizes - bool run_all() { - bool passed = true; - - int problem_size_m[] = { - InterleavedK, 256 + InterleavedK, 512 + InterleavedK - }; - - int problem_size_n[] = { - InterleavedK, 256 + InterleavedK, 512 + InterleavedK - }; - - int problem_size_k[] = { - InterleavedK, 256 + InterleavedK, 512 + InterleavedK - }; - - double problem_alpha[] = { - 1.0 - }; - - double problem_beta[] = { - 2.0 - }; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (double alpha : problem_alpha) { - for (double beta : problem_beta) { - - passed = run( - {m, n, k}, - ElementCompute(alpha), - ElementCompute(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - - return true; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h deleted file mode 100644 index 32452c30e05f64763a268195ae78138f26c09735..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h +++ /dev/null @@ -1,326 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm_planar_complex.h" -#include "cutlass/util/host_tensor_planar_complex.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace device { - -//////////////////////////////////////////////////////////////////////////////// - -template -class TestbedPlanarComplex { -public: - - using ElementA = typename Gemm::ElementA; - using LayoutA = typename Gemm::LayoutA; - using ElementB = typename Gemm::ElementB; - using LayoutB = typename Gemm::LayoutB; - using ElementC = typename Gemm::ElementC; - using LayoutC = typename Gemm::LayoutC; - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - using ElementAccumulator = typename Gemm::ElementAccumulator; - - // - // Data members - // - - cutlass::gemm::GemmCoord problem_size; - cutlass::HostTensorPlanarComplex tensor_A; - cutlass::HostTensorPlanarComplex tensor_B; - cutlass::HostTensorPlanarComplex tensor_C; - cutlass::HostTensorPlanarComplex tensor_D; - cutlass::HostTensorPlanarComplex tensor_D_ref; - - // - // Methods - // - - TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { - - tensor_A.reset({problem_size.m(), problem_size.k()}); - tensor_B.reset({problem_size.k(), problem_size.n()}); - tensor_C.reset({problem_size.m(), problem_size.n()}); - tensor_D.reset({problem_size.m(), problem_size.n()}); - tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); - } - - void initialize() { - - uint64_t seed = 1073; - - int scope_max = 8; - int scope_min = -8; - - cutlass::reference::host::TensorFillRandomUniform( - tensor_A.host_view(), seed, scope_max, scope_min, 0); - - cutlass::reference::host::TensorFillRandomUniform( - tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); - - cutlass::reference::host::TensorFillRandomUniform( - tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); - - cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); - cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - bool run( - cutlass::complex alpha = {1, 0}, - cutlass::complex beta = {0, 0}) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - initialize(); - - int batch_count = 1; - - ElementA *ptr_A = tensor_A.device_data(); - ElementB *ptr_B = tensor_B.device_data(); - ElementC *ptr_C = tensor_C.device_data(); - ElementC *ptr_D = tensor_D.device_data(); - - typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); - typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); - typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); - typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); - - int64_t imag_stride_A = tensor_A.imaginary_stride(); - int64_t imag_stride_B = tensor_B.imaginary_stride(); - int64_t imag_stride_C = tensor_C.imaginary_stride(); - int64_t imag_stride_D = tensor_D.imaginary_stride(); - - // - // Launch device kernel - // - - Gemm gemm_op; - - typename Gemm::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - batch_count, - {alpha, beta}, - ptr_A, - ptr_A + imag_stride_A, - ptr_B, - ptr_B + imag_stride_B, - ptr_C, - ptr_C + imag_stride_C, - ptr_D, - ptr_D + imag_stride_D, - lda, - lda, - ldb, - ldb, - ldc, - ldc, - ldd, - ldd - }; - - cutlass::Status status = gemm_op(args); - - EXPECT_EQ(status, cutlass::Status::kSuccess); - - cudaError_t error = cudaDeviceSynchronize(); - - tensor_D.sync_host(); - - // - // Compute reference - // - - cutlass::reference::host::GemmPlanarComplex< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator - >( - problem_size, - alpha, - tensor_A.host_ref(), - Gemm::kTransformA, - tensor_B.host_ref(), - Gemm::kTransformB, - beta, - tensor_C.host_ref(), - tensor_D_ref.host_ref() - ); - - bool passed = cutlass::reference::host::TensorEquals( - tensor_D.host_view(), - tensor_D_ref.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - std::ofstream output("gemm_planar_complex.txt"); - - output - << "A:\n" << tensor_A.host_view() << "\n" - << "B:\n" << tensor_B.host_view() << "\n" - << "C:\n" << tensor_C.host_view() << "\n" - << "Reference:\n" - << tensor_D_ref.host_view() << "\n" - << "Computed:\n" - << tensor_D.host_view() << "\n"; - } - - return passed; - } -}; - -template -bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { - - TestbedPlanarComplex testbed(problem_size); - - return testbed.run(); -} - -template -bool TestAllGemmPlanarComplex() { - - int M[] = { - 16, 64, 72, 144, 264, 520, - }; - - int N[] = { - 16, 64, 72, 144, 248, 264, 520 - }; - - int K[] = { - 8, 64, 72, 96, 264, 520 - }; - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - cutlass::complex alpha_values[] = { - {ElementCompute(1.25), ElementCompute(-0.5)} - }; - - cutlass::complex beta_values[] = { - {ElementCompute(-2.25), ElementCompute(1.5)} - }; - - for (int m : M) { - for (int n : N) { - for (int k : K) { - - test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); - - for (auto const &alpha : alpha_values) { - for (auto const &beta : beta_values) { - - bool passed = testbed.run(alpha, beta); - if (!passed) { - return false; - } - } - } - } - } - } - - return true; -} - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h deleted file mode 100644 index 4d9f6743a45e5dc3a7b4ddd3e2a7b2abceffbb18..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h +++ /dev/null @@ -1,641 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide Rank 2k update interface - -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/blas3.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/error_metrics.h" -#include "cutlass/util/reference/host/rank_2k.h" -#include "cutlass/util/reference/host/rank_2k_complex.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedRank2KUniversal { - - using ElementA = typename Rank2K::ElementA; - using ElementB = typename Rank2K::ElementB; - using ElementC = typename Rank2K::ElementC; - using ElementAccumulator = typename Rank2K::ElementAccumulator; - using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // - // Methods - // - - TestbedRank2KUniversal( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - - EXPECT_TRUE(false) << "Input distribution not implemented"; - return false; - } - - return true; - } - - - /// Helper to initialize a tensor view - template - bool initialize_symmetric_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillSymmetricRandomUniform( - view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillSymmetricRandomGaussian( - view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); - } - else { - - EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; - return false; - } - - return true; - } - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the Rank2K workspace - // - - tensor_A.resize(problem_size.mk()); - tensor_B.resize(problem_size.mk()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); - EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - - if (reference_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); - - bool passed = l2_norm < cutlass::MantissaInBits::error; - - return passed; - } - - /// Verifies the result is a Rank2K - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - cutlass::reference::host::Rank2KComplex< - typename Rank2K::ElementA, typename Rank2K::LayoutA, - typename Rank2K::ElementB, typename Rank2K::LayoutB, - typename Rank2K::ElementC, typename Rank2K::LayoutC, - ElementCompute, ElementAccumulator - >( - problem_size, - alpha, - tensor_A.host_ref(), - Rank2K::kTransformA, - tensor_B.host_ref(), - Rank2K::kTransformB, - beta, - tensor_C.host_ref(), - reference_D.host_ref(), - ElementAccumulator(0), - Rank2K::kFillModeC, - Rank2K::kBlasMode - ); - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Rank2K::Rank2Kkernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 - std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size - << " alpha: " << ElementCompute(alpha) - << " beta: " << ElementCompute(beta) << std::endl; -#endif - - this->initialize(problem_size); - - // - // Initialize the Rank2K operator - // - - typename Rank2K::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_D.device_data(), - problem_size.n() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0) - }; - - Rank2K rank2k_op; - - size_t workspace_size = Rank2K::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the Rank2K - // - - status = rank2k_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - //if (true) { - if (!passed) { - std::stringstream fname; - - fname << "error_Rank2k_device_" - << "fill_mode_c_" - << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : - (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) - << "mnk_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Rank2K::ThreadblockShape::kM << "x" - << Rank2K::ThreadblockShape::kN << "x" - << Rank2K::ThreadblockShape::kK << "_" - << Rank2K::WarpShape::kM << "x" - << Rank2K::WarpShape::kN << "x" - << Rank2K::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nD reference:\n" << reference_D.host_view() << "\n" - << "\nD computed:\n" << tensor_D.host_view() << "\n"; - - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestRank2kUniversal( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count, - double alpha = 1.0, - double beta = 2.0) { - - bool passed = true; - - TestbedRank2KUniversal testbed; - - using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - return passed; -} - -template -bool TestAllRank2KUniversal() { - bool passed = true; - - - int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); - - int const kAlignment = cutlass::platform::is_same< - typename Rank2K::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = kAlignmentM; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value - ? 4 : kAlignment; - - cutlass::gemm::GemmUniversalMode modes[] = { - cutlass::gemm::GemmUniversalMode::kGemm, - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int problem_size_k[] = { - kAlignmentK, - Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, - Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK - }; - - int batch_counts[] = { // may be interpretted as batch count or split-K slices - 1 // Just running one batch for now (removing 2, 3, 5, 7) - }; - - double problem_alpha[] = { - 1.0, 3.25 - }; - - double problem_beta[] = { - 0.0, 2.15 - }; - - using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; - - for (cutlass::gemm::GemmUniversalMode mode : modes) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int batch_count : batch_counts) { - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - if (mode == cutlass::gemm::GemmUniversalMode::kGemm || - mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { - - // skip very small K problems - //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { - // continue; - //} - } - - cutlass::gemm::GemmCoord problem_size(n, n, k); - - TestbedRank2KUniversal testbed; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -template -bool TestAllRank2KHermitianUniversal() { - bool passed = true; - - using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; - using ElementAccumulator = typename Rank2K::ElementAccumulator; - - int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); - - int const kAlignment = cutlass::platform::is_same< - typename Rank2K::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = kAlignmentM; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value - ? 4 : kAlignment; - - cutlass::gemm::GemmUniversalMode modes[] = { - cutlass::gemm::GemmUniversalMode::kGemm, - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int problem_size_k[] = { - kAlignmentK, - Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, - Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK - }; - - int batch_counts[] = { // may be interpretted as batch count or split-K slices - 1 // Just running one batch for now (removing 2, 3, 5, 7) - }; - - /* Complex alpha for HER2K */ - ElementAccumulator problem_alpha[] = { - {1.0}, - {1.25, 3.25}, - {-0.25, -2.25} - }; - - ElementAccumulator problem_beta[] = { - 0.0, -2.25 - }; - - for (cutlass::gemm::GemmUniversalMode mode : modes) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int batch_count : batch_counts) { - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - if (mode == cutlass::gemm::GemmUniversalMode::kGemm || - mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { - - // skip very small K problems - //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { - // continue; - //} - } - - cutlass::gemm::GemmCoord problem_size(n, n, k); - - TestbedRank2KUniversal testbed; - - passed = testbed.run( - mode, - problem_size, - batch_count, - alpha, - beta - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h deleted file mode 100644 index cb46528a049ae1254d0492b6235821210e47b957..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h +++ /dev/null @@ -1,511 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide Rank 2k update interface - -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/blas3.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/error_metrics.h" -#include "cutlass/util/reference/host/rank_k_complex.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedRank2KUniversal { - - using ElementA = typename RankK::ElementA; - using ElementC = typename RankK::ElementC; - using ElementAccumulator = typename RankK::ElementAccumulator; - using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // - // Methods - // - - TestbedRank2KUniversal( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - - EXPECT_TRUE(false) << "Input distribution not implemented"; - return false; - } - - return true; - } - - - /// Helper to initialize a tensor view - template - bool initialize_symmetric_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillSymmetricRandomUniform( - view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillSymmetricRandomGaussian( - view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); - } - else { - - EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; - return false; - } - - return true; - } - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the RankK workspace - // - - tensor_A.resize(problem_size.mk()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); - EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); - tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - - if (reference_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); - - bool passed = l2_norm < cutlass::MantissaInBits::error; - - return passed; - } - - /// Verifies the result is a RankK - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - cutlass::reference::host::Rank2KComplex< - typename RankK::ElementA, typename RankK::LayoutA, - typename RankK::ElementC, typename RankK::LayoutC, - ElementCompute, ElementAccumulator - >( - problem_size, - alpha, - tensor_A.host_ref(), - RankK::kTransformA, - beta, - tensor_C.host_ref(), - reference_D.host_ref(), - ElementAccumulator(0), - RankK::kFillModeC, - RankK::kBlasMode - ); - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename RankK::RankKkernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 - std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size - << " alpha: " << ElementCompute(alpha) - << " beta: " << ElementCompute(beta) << std::endl; -#endif - - this->initialize(problem_size); - - // - // Initialize the RankK operator - // - - typename RankK::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - tensor_A.device_data(), - tensor_C.device_data(), - tensor_D.device_data(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0) - }; - - RankK rank2k_op; - - size_t workspace_size = RankK::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the RankK - // - - status = rank2k_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - //if (true) { - if (!passed) { - std::stringstream fname; - - fname << "error_RankK_device_" - << "fill_mode_c_" - << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : - (RankK::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) - << "mnk_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << RankK::ThreadblockShape::kM << "x" - << RankK::ThreadblockShape::kN << "x" - << RankK::ThreadblockShape::kK << "_" - << RankK::WarpShape::kM << "x" - << RankK::WarpShape::kN << "x" - << RankK::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nD reference:\n" << reference_D.host_view() << "\n" - << "\nD computed:\n" << tensor_D.host_view() << "\n"; - - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestRank2kUniversal( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count, - double alpha = 1.0, - double beta = 2.0) { - - bool passed = true; - - TestbedRank2KUniversal testbed; - - using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - return passed; -} - -template -bool TestAllRankKUniversal() { - bool passed = true; - - - int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); - int const kAlignmentN = 128 / kMinimumOperandElementSize; - int const kAlignmentK = 128 / kMinimumOperandElementSize; - - cutlass::gemm::GemmUniversalMode modes[] = { - cutlass::gemm::GemmUniversalMode::kGemm, - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int problem_size_k[] = { - kAlignmentK, - RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, - RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK - }; - - int batch_counts[] = { // may be interpretted as batch count or split-K slices - 1 // Just running one batch for now (removing 2, 3, 5, 7) - }; - - double problem_alpha[] = { - 1.0 - }; - - double problem_beta[] = { - 2.0 - }; - - - using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; - - for (cutlass::gemm::GemmUniversalMode mode : modes) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int batch_count : batch_counts) { - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - if (mode == cutlass::gemm::GemmUniversalMode::kGemm || - mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { - } - - cutlass::gemm::GemmCoord problem_size(n, n, k); - - TestbedRank2KUniversal testbed; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h deleted file mode 100644 index 0a01a6a32ee2db84f2e890059423cd6b8477f766..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h +++ /dev/null @@ -1,238 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/core_io.h" - -#include "testbed.h" - - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// List of Gemm internal paramters this testbed supports user verification -// -enum class ParameterID { - - // Threadblock-level parameters - kSmemASize, - kSmemBSize, - - // Warp-level parameters - kWarpFragmentASize, - kWarpFragmentBSize, - kWarpFragmentCSize, - kInvalid -}; - -struct Reference { - ParameterID parameter_id; - - union { - int value; - - struct { - int m, n, k; - } gemm_shape; - - struct { - int row, column; - } matrix_shape; - }; - - std::string error_msg; - - Reference( - ParameterID parameter_id_, - int value_=-1, - std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} -}; - - -template -struct TestbedSanity { - - // - // Type definitions (All Gemm types top down) - // - - // Unpacking Gemm types in the following order - // Kernel-level > Threadblock-level > Warp-level > Instruction-level - - // kernel-level cutlass Gemm - using GemmKernel = typename Gemm::GemmKernel; - - // - // Threadblock-level gemm types - // - using MmaThreadBlock = typename GemmKernel::Mma; - - // Threadblock-level gemm shape covering one stage - using ThreadblockShape = typename MmaThreadBlock::Shape; - - // Shared memory size covering all stages - using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; - using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; - using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; - using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; - - - /// Number of stages - static int const kStages = MmaThreadBlock::Base::kStages; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; - - - // - // Warp-level gemm types - // - - // Warp-level gemm operator - using MmaWarp = typename MmaThreadBlock::Operator; - - // Warp-level gemm shape covering all kgroups - using WarpShape = typename MmaWarp::Shape; - - // Warp-level framents holding operands A & B operand and destination C - using WarpFragmentA = typename MmaWarp::FragmentA; - using WarpFragmentB = typename MmaWarp::FragmentB; - using WarpFragmentC = typename MmaWarp::FragmentC; - - // - // Instruction-level gemm types - // - - // Instruction-level gemm operator - using MmaInstruction = typename MmaWarp::Policy::Operator; - - // Instruction shape - using InstructionShape = typename MmaInstruction::Shape; - - // Instruction-level framents holding operands A & B operand and destination C - using InstructionFragmentA = typename MmaInstruction::FragmentA; - using InstructionFragmentB = typename MmaInstruction::FragmentB; - using InstructionFragmentC = typename MmaInstruction::FragmentC; - - // - // Testbed types - // - - // Vector of values holding user provided reference - using ReferenceVector = std::vector; - - // - // Data members - // - ReferenceVector references; - - // - // Methods - // - - TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } - - // verify all parameter in ReferenceVector - bool verify() { - for(auto ref : references) - verify_parameter(ref); - return true; - } - - // verify parameter of type Reference - void verify_parameter(Reference const& ref) { - switch(ref.parameter_id) { - case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; - case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; - case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; - } - } - -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Overload output operators for TesbedSanity -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -template -std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { - - - out << "Gemm internal parameters" << std::endl - << " Threadblock-level parameters:" << std::endl - << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl - << " kStages = " << TestbedSanity::kStages << std::endl - << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl - <<" Shared memory sizes:" << std::endl - <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl - <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl - <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl - <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl - <<" Warp-level parameters" << std::endl - <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl - <<" Fragment sizes:" << std::endl - <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl - <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl - <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl - <<" Instruction-level parameters" << std::endl - <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl - <<" Fragment sizes:" << std::endl - <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl - <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl - <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; - - return out; -} - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h deleted file mode 100644 index a95bf996bac337b44da616dc9fbf9c9bdb2a625c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h +++ /dev/null @@ -1,487 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface - - Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/host_reorder.h" -#include "cutlass/util/host_uncompress.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SparseTestbed { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - static int const kSparse = Gemm::GemmKernel::kSparse; - static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; - static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; - static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; - - using ElementE = typename Gemm::GemmKernel::ElementE; - using LayoutE = cutlass::layout::RowMajor; - using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - cutlass::Distribution::Kind init_E; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_A_uncompressed; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - cutlass::HostTensor tensor_E; - cutlass::HostTensor tensor_E_reordered; - - // - // Methods - // - - SparseTestbed( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080) - : init_A(init_A_), - init_B(init_B_), - init_C(init_C_), - init_E(init_E_), - seed(seed_) {} - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 1; - scope_min = -1; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the GEMM workspace - // - tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); - tensor_A_uncompressed.resize(problem_size.mk()); - tensor_B.resize(problem_size.kn()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - tensor_E.resize(cutlass::make_Coord( - problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); - tensor_E_reordered.resize(cutlass::make_Coord( - problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - - if (init_E == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomSparseMeta( - tensor_E.host_view(), seed, kMetaSizeInBits); - } else if (init_E == cutlass::Distribution::Identity) { - uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; - cutlass::reference::host::TensorFill(tensor_E.host_view(), - (ElementE)(content)); - } else { - EXPECT_TRUE(false); - } - - cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), - {problem_size.m(), problem_size.n(), - problem_size.k() / kSparse / kElementsPerElementE}); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - tensor_E_reordered.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - - if (reference_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); - - EXPECT_TRUE(passed); - - if (!passed) { - - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\nE =\n" << tensor_E.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\nComputed =\n" << tensor_D.host_view(); - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - - cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), - tensor_E.host_ref(), problem_size.m(), problem_size.k()); - - cutlass::reference::host::Gemm< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, - ElementCompute, - ElementAccumulator, typename Gemm::Operator> - reference_gemm; - - reference_gemm( - problem_size, - alpha, - tensor_A_uncompressed.host_ref(), - tensor_B.host_ref(), - beta, - reference_D.host_ref(), - ElementAccumulator(0) - ); - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmCoord problem_size, - int split_k_slices = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - split_k_slices, - {alpha, beta}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_D.device_data(), - tensor_E_reordered.device_data(), - int64_t(), - int64_t(), - int64_t(), - int64_t(), - int64_t(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0), - tensor_E_reordered.layout().stride(0) - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - // This failure is likely due to insufficient device capabilities. Waive the test. - if (status != cutlass::Status::kSuccess) { - return true; - } - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << ", beta: " << beta << ", m: " << problem_size.m() << ", n: " << problem_size.n() << ", k:" < -bool TestAllSparseGemm() { - bool passed = true; - - int const kMinimumOperandElementSize = - std::min( - int(cutlass::sizeof_bits::value), - int(cutlass::sizeof_bits::value)); - - // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) - // because of the reordering of operand E - int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), - kMinimumOperandElementSize); - - int const kAlignmentN = 128 / kMinimumOperandElementSize; - - int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; - - int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; - - int problem_size_k[] = {Gemm::ThreadblockShape::kK * 8}; - - int split_k_slices[] = { - 1, 2 - }; - - double problem_alpha[] = { - 1 - }; - - double problem_beta[] = { - 2.0 - }; - - SparseTestbed testbed; - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int split_k : split_k_slices) { - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - cutlass::gemm::GemmCoord problem_size(m, n, k); - - passed = testbed.run( - problem_size, - split_k, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h deleted file mode 100644 index 8fa4a85505316d08f1d050702b78448f8fae8565..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h +++ /dev/null @@ -1,218 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "testbed.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedSplitK : public Testbed { - - using Base = Testbed; - - using ElementCompute = typename Base::ElementCompute; - - // - // Methods - // - - TestbedSplitK( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - Base(init_A_, init_B_, init_C_, seed_) { } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmCoord problem_size, - int split_k_slices, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - problem_size, - this->tensor_A.device_ref(), - this->tensor_B.device_ref(), - this->tensor_C.device_ref(), - this->tensor_D.device_ref(), - {alpha, beta}, - split_k_slices - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - - // - // Verify - // - - return this->verify(problem_size, alpha, beta); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -bool TestAllGemmSplitK() { - bool passed = true; - - cutlass::gemm::GemmCoord problem_sizes[] = { - {8, 8, 2048}, - {8, 8, 2056}, - {264, 72, 520}, - {264, 520, 120}, - {264, 520, 264} - }; - - int split_k_slices[] = { - 1, 2, 4, 5, 7 - }; - - double problem_alpha[] = { - 0.5 - }; - - double problem_beta[] = { - 2.0 - }; - - using Testbed = TestbedSplitK; - using ElementCompute = typename Testbed::ElementCompute; - - Testbed testbed; - - for (auto problem_size : problem_sizes) { - for (int split_k_count : split_k_slices) { - for (double alpha : problem_alpha) { - for (double beta : problem_beta) { - - passed = testbed.run( - problem_size, - split_k_count, - ElementCompute(alpha), - ElementCompute(beta) - ); - - if (!passed) { - std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; - return false; - } - } - } - } - } - - EXPECT_TRUE(passed); - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h deleted file mode 100644 index b7a57f7eb0ca73c23460e5a9ce1301061c2cc286..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h +++ /dev/null @@ -1,592 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide Symm update interface - -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/blas3.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/error_metrics.h" -#include "cutlass/util/reference/host/symm.h" -#include "cutlass/util/reference/host/symm_complex.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedSymmUniversal { - - using ElementA = typename Symm::ElementA; - using ElementB = typename Symm::ElementB; - using ElementC = typename Symm::ElementC; - using ElementAccumulator = typename Symm::ElementAccumulator; - using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // - // Methods - // - - TestbedSymmUniversal( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - - EXPECT_TRUE(false) << "Input distribution not implemented"; - return false; - } - - return true; - } - - - /// Helper to initialize a tensor view - template - bool initialize_symmetric_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillSymmetricRandomUniform( - view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillSymmetricRandomGaussian( - view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); - } - else { - - EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; - return false; - } - - return true; - } - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the Symm workspace - // - - if (Symm::kSideModeA == cutlass::SideMode::kLeft) { - tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); - } - else if (Symm::kSideModeA == cutlass::SideMode::kRight) { - tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); - } - - tensor_B.resize(problem_size.mn()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - - EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - if (tensor_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - - if (reference_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); - - bool passed = l2_norm < cutlass::MantissaInBits::error; - - return passed; - } - - /// Verifies the result is a Symm - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - - using HostReference = typename cutlass::platform::conditional< - (cutlass::platform::is_same - >::value || - cutlass::platform::is_same - >::value - ), - cutlass::reference::host::SymmComplex< - typename Symm::ElementA, typename Symm::LayoutA, - Symm::kSideModeA, Symm::kFillModeA, - typename Symm::ElementB, typename Symm::LayoutB, - typename Symm::ElementC, typename Symm::LayoutC, - ElementCompute, - ElementAccumulator, - Symm::kBlasMode>, - cutlass::reference::host::Symm< - typename Symm::ElementA, typename Symm::LayoutA, - Symm::kSideModeA, Symm::kFillModeA, - typename Symm::ElementB, typename Symm::LayoutB, - typename Symm::ElementC, typename Symm::LayoutC, - ElementCompute, - ElementAccumulator> - >::type; - - - HostReference reference_symm; - - reference_symm( - problem_size, - alpha, - tensor_A.host_ref(), - tensor_B.host_ref(), - beta, - tensor_C.host_ref(), - reference_D.host_ref(), - ElementAccumulator(0) - ); - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Symm::SymmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 - std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size - << " alpha: " << ElementCompute(alpha) - << " beta: " << ElementCompute(beta) << std::endl; -#endif - - this->initialize(problem_size); - - // - // Initialize the Symm operator - // - - int batch_stride_A; - if (Symm::kSideModeA == cutlass::SideMode::kLeft) - batch_stride_A = problem_size.m()*problem_size.m(); - if (Symm::kSideModeA == cutlass::SideMode::kRight) - batch_stride_A = problem_size.n()*problem_size.n(); - - typename Symm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_D.device_data(), - batch_stride_A, - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0) - }; - - Symm symm_op; - - size_t workspace_size = Symm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = symm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the Symm - // - - status = symm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - //if (true) { - if (!passed) { - std::stringstream fname; - - fname << "error_" - << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) - << "device_" - << "fill_mode_a_" - << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : - (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) - << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : - (Symm::kFillModeA == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) - << "mnk_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Symm::ThreadblockShape::kM << "x" - << Symm::ThreadblockShape::kN << "x" - << Symm::ThreadblockShape::kK << "_" - << Symm::WarpShape::kM << "x" - << Symm::WarpShape::kN << "x" - << Symm::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "alpha: " << ElementCompute(alpha) << "\n" - << "beta: " << ElementCompute(beta) << "\n" - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nD reference:\n" << reference_D.host_view() << "\n" - << "\nD computed:\n" << tensor_D.host_view() << "\n"; - - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestsymmUniversal( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count, - double alpha = 1.0, - double beta = 2.0) { - - bool passed = true; - - TestbedSymmUniversal testbed; - - using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - return passed; -} - -template -bool TestAllSymmUniversal() { - bool passed = true; - - - int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); - - int const kAlignment = cutlass::platform::is_same< - typename Symm::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = kAlignmentM; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value - ? 4 : kAlignment; - - cutlass::gemm::GemmUniversalMode modes[] = { - cutlass::gemm::GemmUniversalMode::kGemm, - }; - - int problem_size_m[] = { - kAlignmentK, - Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, - Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int batch_counts[] = { // may be interpretted as batch count or split-K slices - 1 // Just running one batch for now (removing 2, 3, 5, 7) - }; - - double problem_alpha[] = { - 1.0, 3.0 - }; - - double problem_beta[] = { - 0, 2.0 - }; - - - using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; - - for (cutlass::gemm::GemmUniversalMode mode : modes) { - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int batch_count : batch_counts) { - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - int k = 0; - if (Symm::kSideModeA == cutlass::SideMode::kLeft) - k = m; - else if (Symm::kSideModeA == cutlass::SideMode::kRight) - k = n; - - if (mode == cutlass::gemm::GemmUniversalMode::kGemm || - mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { - - #if 0 - // skip very small K problems - if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { - continue; - } - #endif - } - - cutlass::gemm::GemmCoord problem_size(m, n, k); - - TestbedSymmUniversal testbed; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h deleted file mode 100644 index b30acfed6bba547986efd3afa8eb829be2a255e4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h +++ /dev/null @@ -1,606 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide TRMM interface - - -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/blas3.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/error_metrics.h" -#include "cutlass/util/reference/host/trmm.h" -#include "cutlass/util/reference/host/trmm_complex.h" -#include "cutlass/core_io.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedTrmmUniversal { - - using ElementA = typename Trmm::ElementA; - using ElementB = typename Trmm::ElementB; - using ElementC = typename Trmm::ElementC; - using ElementAccumulator = typename Trmm::ElementAccumulator; - using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_D; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // - // Methods - // - - TestbedTrmmUniversal( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - - /// Helper to initialize a tensor view - template - bool initialize_symmetric_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int mantissa_in_bits) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillSymmetricRandomUniform( - view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillSymmetricRandomGaussian( - view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) - template - bool initialize_pad_diagonal_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed, - int alignment) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillPadDiagonalRandomUniform( - view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the TRMM workspace - // - - if (Trmm::kSideMode == cutlass::SideMode::kLeft) { - tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); - } - else if (Trmm::kSideMode == cutlass::SideMode::kRight) { - tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); - } - - tensor_B.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - - //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); - //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, cutlass::MantissaInBits::bits)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2019, cutlass::MantissaInBits::bits)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_D.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - - if (tensor_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - - if (reference_D.size() > 1) - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); - - bool passed = l2_norm < cutlass::MantissaInBits::error; - - return passed; - } - - /// Verifies the result is a TRMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha) { - - // - // Verify - // - - using HostReference = typename cutlass::platform::conditional< - (cutlass::platform::is_same - >::value || - cutlass::platform::is_same - >::value - ), - cutlass::reference::host::TrmmComplex< - typename Trmm::ElementA, typename Trmm::LayoutA, - Trmm::kTransformA, - Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, - typename Trmm::ElementB, typename Trmm::LayoutB, - Trmm::kTransformB, - typename Trmm::ElementC, typename Trmm::LayoutC, - ElementCompute, - ElementAccumulator>, - cutlass::reference::host::Trmm< - typename Trmm::ElementA, typename Trmm::LayoutA, - Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, - typename Trmm::ElementB, typename Trmm::LayoutB, - typename Trmm::ElementC, typename Trmm::LayoutC, - ElementCompute, - ElementAccumulator> - >::type; - - - HostReference reference_trmm; - - reference_trmm( - problem_size, - alpha, - tensor_A.host_ref(), - tensor_B.host_ref(), - reference_D.host_ref(), - ElementAccumulator(0) - ); - - return compare_reference(problem_size, alpha); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Trmm::TrmmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementCompute alpha = ElementCompute(1)) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - -#if 0 - std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size - << " alpha: " << ElementCompute(alpha) << std::endl; -#endif - - this->initialize(problem_size); - - // - // Initialize the TRMM operator - // - - int batch_stride_A; - if (Trmm::kSideMode == cutlass::SideMode::kLeft) - batch_stride_A = problem_size.m()*problem_size.m(); - if (Trmm::kSideMode == cutlass::SideMode::kRight) - batch_stride_A = problem_size.n()*problem_size.n(); - - typename Trmm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_D.device_data(), - batch_stride_A, - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_D.layout().stride(0) - }; - - Trmm trmm_op; - - size_t workspace_size = Trmm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the TRMM - // - - status = trmm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - bool passed = this->verify(problem_size, alpha); - - if (!passed) { - std::stringstream fname; - - fname << "error_Trmm_device_" - << "fill_mode_" - << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : - (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) - << "side_mode_" - << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : - (Trmm::kSideMode == cutlass::SideMode::kRight ? "right_" : "invalid_")) - << "mnk_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Trmm::ThreadblockShape::kM << "x" - << Trmm::ThreadblockShape::kN << "x" - << Trmm::ThreadblockShape::kK << "_" - << Trmm::WarpShape::kM << "x" - << Trmm::WarpShape::kN << "x" - << Trmm::WarpShape::kK << ".txt"; - - std::cout << fname.str() << std::endl; - - std::ofstream results(fname.str()); - - results << problem_size << std::endl; - - results - << "\nA:\n" << tensor_A.host_view() << "\n" - << "\nB:\n" << tensor_B.host_view() << "\n" - << "\nD reference:\n" << reference_D.host_view() << "\n" - << "\nD computed:\n" << tensor_D.host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestTrmmUniversal( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count, - double alpha = 1.0) { - - bool passed = true; - - TestbedTrmmUniversal testbed; - - using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha) - ); - - return passed; -} - -template -bool TestAllTrmmUniversal() { - bool passed = true; - - int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); - - int const kAlignment = cutlass::platform::is_same< - typename Trmm::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = kAlignmentM; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value - ? 4 : kAlignment; - - cutlass::gemm::GemmUniversalMode modes[] = { - cutlass::gemm::GemmUniversalMode::kGemm, - }; - - int problem_size_m[] = { - kAlignmentK, - Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, - Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int batch_counts[] = { // may be interpretted as batch count or split-K slices - 1 // Just running one batch for now (removing 2, 3, 5, 7) - }; - - double problem_alpha[] = { - 1.0, 2.0 - }; - - using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; - - for (cutlass::gemm::GemmUniversalMode mode : modes) { - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int batch_count : batch_counts) { - for (auto alpha : problem_alpha) { - - int k = 0; - if (Trmm::kSideMode == cutlass::SideMode::kLeft) - k = m; - else if (Trmm::kSideMode == cutlass::SideMode::kRight) - k = n; - - if (mode == cutlass::gemm::GemmUniversalMode::kGemm || - mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { - -#if 0 - // skip very small K problems - if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { - continue; - } -#endif - } - - cutlass::gemm::GemmCoord problem_size(m, n, k); - - TestbedTrmmUniversal testbed; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h deleted file mode 100644 index 00368a5e8eebc128719f64069583010c83dc0c1f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h +++ /dev/null @@ -1,553 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/gemm_complex.h" - -#include "testbed_utils.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TestbedUniversal { - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - - /// Initialization - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_B; - cutlass::Distribution::Kind init_C; - uint64_t seed; - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor reference_D; - - // - // Methods - // - - TestbedUniversal( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 2080 - ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - - /// Helper to initialize a tensor view - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { - - if (dist_kind == cutlass::Distribution::Uniform) { - - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; - bool is_unsigned_int = std::numeric_limits::is_integer && !std::numeric_limits::is_signed; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = is_unsigned_int ? 2 : 1; - scope_min = is_unsigned_int ? 0 : -1; - } else if (bits_output == 16) { - constexpr auto u8_bf16 = - (cutlass::platform::is_same::value && - cutlass::platform::is_same::value) || - (cutlass::platform::is_same::value && - cutlass::platform::is_same::value); - scope_max = is_unsigned_int ? 10 : (u8_bf16 ? 3 : 5); - scope_min = is_unsigned_int ? 0 : (u8_bf16 ? -3 : -5); - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential( - view.data(), view.capacity()); - } - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the GEMM workspace - // - - tensor_A.resize(problem_size.mk()); - tensor_B.resize(problem_size.kn()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); - - EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); - EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); - EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - cutlass::Coord<2> origin(0); - tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); - tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); - tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D.sync_device(); - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - - bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); - - EXPECT_TRUE(passed) << " mismatched reference"; - - if (!passed) { - - /* - - std::stringstream fname; - - fname << "error_Gemm_device_" - << problem_size.m() << "x" - << problem_size.n() << "x" - << problem_size.k() << "_" - << Gemm::ThreadblockShape::kM << "x" - << Gemm::ThreadblockShape::kN << "x" - << Gemm::ThreadblockShape::kK << "_" - << Gemm::WarpShape::kM << "x" - << Gemm::WarpShape::kN << "x" - << Gemm::WarpShape::kK << ".txt"; - - std::ofstream file(fname.str()); - */ - - std::ofstream file("testbed_universal_errors.txt"); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << tensor_A.host_view() - << "\nB =\n" << tensor_B.host_view() - << "\nC =\n" << tensor_C.host_view() - << "\n\nReference =\n" << reference_D.host_view() - << "\nComputed =\n" << tensor_D.host_view(); - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - // - // Verify - // - - cutlass::reference::host::GemmComplex< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, - ElementCompute, ElementAccumulator - >( - problem_size, - alpha, - tensor_A.host_ref(), - Gemm::kTransformA, - tensor_B.host_ref(), - Gemm::kTransformB, - beta, - tensor_C.host_ref(), - reference_D.host_ref(), - ElementAccumulator(0) - ); - - if (Relu) { - for (int i = 0; i < problem_size.m(); ++i) { - for (int j = 0; j < problem_size.n(); ++j) { - reference_D.at(cutlass::MatrixCoord(i, j)) = - ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) - ? (typename Gemm::ElementC)0 - : reference_D.at(cutlass::MatrixCoord(i, j)); - } - } - } - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - - return true; - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) - { -/* - std::cout << "\n-----------------------\n"; - std::cout << "mode: " << (int) mode << "\n"; - std::cout << "problem size: " << problem_size << "\n"; - std::cout << "batch_count: " << batch_count << "\n"; - std::cout << "alpha: " << alpha << "\n"; - std::cout << "beta: " << beta << "\n"; - std::cout << "-----------------------\n\n"; -*/ - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::Arguments arguments{ - mode, - problem_size, - batch_count, - {alpha, beta}, - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_D.device_data(), - problem_size.m() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - tensor_A.layout().stride(0), - tensor_B.layout().stride(0), - tensor_C.layout().stride(0), - tensor_D.layout().stride(0) - }; - - Gemm gemm_op; - - size_t workspace_size = Gemm::get_workspace_size(arguments); - - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestGemmUniversal( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmUniversalMode mode, - int batch_count, - double alpha = 1.0, - double beta = 2.0) { - - bool passed = true; - - TestbedUniversal testbed; - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - return passed; -} - -template -bool TestAllGemmUniversal() { - bool passed = true; - - - int const kMinimumOperandElementSize = - std::min( - int(cutlass::sizeof_bits::value), - int(cutlass::sizeof_bits::value)); - - int const kAlignment = cutlass::platform::is_same< - typename Gemm::OperatorClass, - cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; - - // int8_t gemm alignment constraints - int const kAlignmentM = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentN = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value ? 4 : kAlignment; - - int const kAlignmentK = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - (cutlass::platform::is_same::value || - cutlass::platform::is_same::value) ? 4 : kAlignment; - - - - cutlass::gemm::GemmUniversalMode modes[] = { - cutlass::gemm::GemmUniversalMode::kGemm, - }; - - int problem_size_m[] = { - kAlignmentM, 512 - 3*kAlignmentM - }; - - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; - - int problem_size_k[] = { - kAlignmentK, - Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, - Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK - }; - - int batch_counts[] = { // may be interpretted as batch count or split-K slices - 1, 2, 3, 5, 7 - }; - - double problem_alpha[] = { - 1 - }; - - double problem_beta[] = { - 2.0 - }; - - - using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; - - for (cutlass::gemm::GemmUniversalMode mode : modes) { - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - for (int batch_count : batch_counts) { - - for (auto alpha : problem_alpha) { - for (auto beta : problem_beta) { - - if (mode == cutlass::gemm::GemmUniversalMode::kGemm || - mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { - - // skip very small K problems - if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { - continue; - } - } - - cutlass::gemm::GemmCoord problem_size(m, n, k); - - TestbedUniversal testbed; - - passed = testbed.run( - mode, - problem_size, - batch_count, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - if (!passed) { - return false; - } - } - } - } - } - } - } - } - - /* - // large problem with high coverage - for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { - TestbedUniversal testbed; - - cutlass::gemm::GemmCoord problem_size(72, 56, 8192); - - passed = testbed.run( - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - split_k_slices, - cutlass::from_real(1.0), - cutlass::from_real(2.0) - ); - - if (!passed) { - break; - } - } - */ - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h deleted file mode 100644 index 89ac33a1028061515d08d50fdb6cce7833ae88ce..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h +++ /dev/null @@ -1,53 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -inline char const *to_string(cutlass::Status status) { - - switch (status) { - case cutlass::Status::kSuccess: return "kSuccess"; - case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; - case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; - case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; - case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; - case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; - case cutlass::Status::kErrorInternal: return "kErrorInternal"; - case cutlass::Status::kInvalid: return "kInvalid"; - default: break; - } - return "invalid"; -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h deleted file mode 100644 index 8b5588f57c40c4e8f8d06adfa9f1e673350fb5e5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h +++ /dev/null @@ -1,609 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Testbed for running device-level GEMMs with absolute maximum calculation and scaling -*/ - -#pragma once - -#include -#include -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm_complex.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gemm.h" - -#include "testbed.h" -#include "testbed_sparse.h" -#include "testbed_utils.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" - -namespace test { -namespace gemm { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - typename GemmTestbed, - template class ActivationFunctor -> -struct TestbedWithAmax { - - static_assert(std::is_same_v> || std::is_same_v>); - static constexpr bool IsSparseTestbed = std::is_same_v>; - - using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; - using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor; - using ElementAbsmax = typename Gemm::EpilogueOutputOp::ElementAbsmax; - - static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; - static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; - bool doScaleA; - bool doScaleB; - bool doScaleC; - - GemmTestbed underlying_testbed; - - cutlass::HostTensor tensor_Aux; - cutlass::HostTensor tensor_Vector; - cutlass::HostTensor tmp_D; - cutlass::HostTensor reference_D; - cutlass::HostTensor reference_Aux; - cutlass::HostTensor scale_A; - cutlass::HostTensor scale_B; - cutlass::HostTensor scale_C; - cutlass::HostTensor scale_D; - cutlass::HostTensor scale_Aux; - cutlass::HostTensor abs_max_Aux; - cutlass::HostTensor abs_max_D; - cutlass::HostTensor reference_abs_max_Aux; - cutlass::HostTensor reference_abs_max_D; - - // - // Methods - // - - TestbedWithAmax( - bool scaleA = true, - bool scaleB = true, - bool scaleC = true, - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform - ): - doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), - underlying_testbed(init_A_, init_B_, init_C_) { } - - /// Helper to initialize scaling factors - template - bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { - cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); - return true; - } - - /// Initializes data structures - void initialize(cutlass::gemm::GemmCoord problem_size) { - // - // Allocate the GEMM workspace - // - underlying_testbed.initialize(problem_size); - - tensor_Vector.resize({1, problem_size.n()}); - reference_D.resize(problem_size.mn(), false); - tmp_D.resize(problem_size.mn(), false); - - EXPECT_TRUE( - underlying_testbed.initialize_tensor(tensor_Vector.host_view(), underlying_testbed.init_C, underlying_testbed.seed + 2020) - ); - - // It is possible to randomly initialize to all zeros, so override this with non-zeros - // in the upper left corner of each operand. - cutlass::Coord<2> origin(0); - tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1); - - cutlass::reference::host::TensorCopy(reference_D.host_view(), underlying_testbed.tensor_C.host_view()); - - tensor_Vector.sync_device(); - - int scale_bits = 2; - if (doScaleA) { - scale_A.resize({1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), underlying_testbed.seed + 2021, scale_bits)); - scale_A.sync_device(); - } - - if (doScaleB) { - scale_B.resize({1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), underlying_testbed.seed + 2022, scale_bits)); - scale_B.sync_device(); - } - - if (doScaleC) { - scale_C.resize({1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), underlying_testbed.seed + 2023, scale_bits)); - scale_C.sync_device(); - } - - if (kScaleOutput) { - scale_D.resize({1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), underlying_testbed.seed + 2024, scale_bits)); - scale_D.sync_device(); - - abs_max_D.resize({1, 1}); - cutlass::reference::host::TensorFill(abs_max_D.host_view()); - abs_max_D.sync_device(); - - reference_abs_max_D.resize({1, 1}); - } - - if (kScaleAux) { - tensor_Aux.resize(problem_size.mn()); - cutlass::reference::host::TensorFill(tensor_Aux.host_view()); - tensor_Aux.sync_device(); - - scale_Aux.resize({1, 1}); - EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), underlying_testbed.seed + 2025, scale_bits)); - scale_Aux.sync_device(); - - abs_max_Aux.resize({1, 1}); - cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); - abs_max_Aux.sync_device(); - - reference_Aux.resize(problem_size.mn(), false); - reference_abs_max_Aux.resize({1, 1}); - } - } - - /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - underlying_testbed.tensor_D.sync_host(); - - EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_A.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_B.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_C.host_view()), 0); - - EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); - bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), underlying_testbed.tensor_D.host_view()); - if (!passed) { - std::cout << "Comparison of D failed" << std::endl; - } - - if (kScaleAux) { - tensor_Aux.sync_host(); - abs_max_Aux.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); - if (!cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view())) { - passed = false; - std::cout << "Comparison of Aux failed" << std::endl; - } - if (!cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view())) { - passed = false; - std::cout << "Comparison of Aux absmax failed" << std::endl; - } - } - - if (kScaleOutput) { - abs_max_D.sync_host(); - EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); - if (!cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view())) { - passed = false; - std::cout << "Comparison of D absmax failed" << std::endl; - } - } - - EXPECT_TRUE(passed) << " mismatched reference"; - - if (!passed) { - - std::ofstream file("testbed_with_amax_errors.txt"); - - file - << "problem: " << problem_size - << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - - file - << "A =\n" << underlying_testbed.tensor_A.host_view() - << "\nB =\n" << underlying_testbed.tensor_B.host_view() - << "\nC =\n" << underlying_testbed.tensor_C.host_view() - << "\nVector =\n" << tensor_Vector.host_view() - << "\nScaleA = " << scale_A.host_view() - << "\nScaleB = " << scale_B.host_view() - << "\nScaleC = " << scale_C.host_view() - << "\nScaleD = " << scale_D.host_view() - << "\nScaleAux = " << scale_Aux.host_view() - << "\n\nReference D =\n" << reference_D.host_view() - << "\nComputed D =\n" << underlying_testbed.tensor_D.host_view(); - if (kScaleAux) { - file - << "\n\nReference Aux =\n" << reference_Aux.host_view() - << "\nComputed Aux =\n" << tensor_Aux.host_view() - << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() - << "\nComputed Absmax Aux = " << abs_max_Aux.host_view(); - } - if (kScaleOutput) { - file - << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() - << "\nComputed Absmax D = " << abs_max_D.host_view(); - } - } - - return passed; - } - - /// Verifies the result is a GEMM - bool verify( - cutlass::gemm::GemmCoord problem_size, - ElementCompute alpha, - ElementCompute beta) { - - cutlass::Coord<2> origin(0); - ElementCompute scaled_alpha = alpha; - if (doScaleA) { - scaled_alpha *= scale_A.host_view().at(origin); - } - if (doScaleB) { - scaled_alpha *= scale_B.host_view().at(origin); - } - - ElementCompute scaled_beta = beta; - if (doScaleC) { - scaled_beta *= scale_C.host_view().at(origin); - } - - // - // Verify - // - - auto ref_tA = [&](){ - if constexpr (IsSparseTestbed) { - cutlass::uncompress( - underlying_testbed.tensor_A_uncompressed.host_ref(), - underlying_testbed.tensor_A.host_ref(), - underlying_testbed.tensor_E.host_ref(), - problem_size.m(), - problem_size.k() - ); - return underlying_testbed.tensor_A_uncompressed.host_ref(); - } - else { - return underlying_testbed.tensor_A.host_ref(); - } - }(); - - // Run reference kernel with ElementOutput of type ElementAccumulator - // so that we can compute the absmax epilogue on data that is of type - // ElementAccumulator (which is what the GEMM we are testing will do). - cutlass::reference::host::GemmComplex< - typename Gemm::ElementA, typename Gemm::LayoutA, - typename Gemm::ElementB, typename Gemm::LayoutB, - typename Gemm::ElementC, typename Gemm::LayoutC, - ElementCompute, ElementAccumulator, ElementAccumulator - >( - problem_size, - scaled_alpha, - ref_tA, - Gemm::kTransformA, - underlying_testbed.tensor_B.host_ref(), - Gemm::kTransformB, - scaled_beta, - underlying_testbed.tensor_C.host_ref(), - tmp_D.host_ref(), - ElementAccumulator(0) - ); - - ElementCompute tmp_abs_max_Aux(0.); - ElementCompute tmp_abs_max_D(0.); - - cutlass::NumericConverter cvt_c_to_compute; - cutlass::NumericConverter cvt_accum_to_compute; - cutlass::NumericConverter cvt_compute_to_absmax; - cutlass::NumericConverter cvt_compute_to_d; - cutlass::NumericConverter cvt_compute_to_aux; - - cutlass::absolute_value_op abs; - cutlass::maximum_with_nan_propogation max; - ActivationFunctor act; - - ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); - - for (int m = 0; m < problem_size.m(); ++m) { - for (int n = 0; n < problem_size.n(); ++n) { - ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n})); - ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n})); - ElementCompute aux = intermediate + bias; - ElementCompute d = act(aux); - tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); - tmp_abs_max_D = max(abs(d), tmp_abs_max_D); - reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale); - - if (kScaleAux) { - reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); - } - } - } - - if (kScaleAux) { - reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); - } - - if (kScaleOutput) { - reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); - } - - return compare_reference(problem_size, alpha, beta); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - return underlying_testbed.sufficient(); - } - - /// Executes one test - bool run( - cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, - int batch_count = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) - { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - this->initialize(problem_size); - - // - // Initialize the GEMM operator - // - - typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; - typename Gemm::EpilogueOutputOp::Params epilogue_params{ - activation_params, - scale_A.device_data(), - scale_B.device_data(), - scale_C.device_data(), - scale_D.device_data(), - scale_Aux.device_data(), - abs_max_Aux.device_data(), - abs_max_D.device_data() - }; - - auto arguments = [&]() { - if constexpr (IsSparseTestbed) { - return typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - batch_count, - epilogue_params, - underlying_testbed.tensor_A.device_data(), - underlying_testbed.tensor_B.device_data(), - underlying_testbed.tensor_C.device_data(), - underlying_testbed.tensor_D.device_data(), - underlying_testbed.tensor_E_reordered.device_data(), - tensor_Aux.device_data(), - tensor_Vector.device_data(), - int64_t(), - int64_t(), - int64_t(), - int64_t(), - int64_t(), - int64_t(), - int64_t(), - underlying_testbed.tensor_A.layout().stride(0), - underlying_testbed.tensor_B.layout().stride(0), - underlying_testbed.tensor_C.layout().stride(0), - underlying_testbed.tensor_D.layout().stride(0), - underlying_testbed.tensor_E_reordered.layout().stride(0), - tensor_Aux.layout().stride(0), - 0 // stride vector - }; - } - else { - return typename Gemm::Arguments{ - mode, - problem_size, - batch_count, - epilogue_params, - underlying_testbed.tensor_A.device_data(), - underlying_testbed.tensor_B.device_data(), - underlying_testbed.tensor_C.device_data(), - underlying_testbed.tensor_D.device_data(), - tensor_Aux.device_data(), - tensor_Vector.device_data(), - problem_size.m() * problem_size.k(), - problem_size.n() * problem_size.k(), - problem_size.m() * problem_size.n(), - problem_size.m() * problem_size.n(), - 0, // stride vector - underlying_testbed.tensor_A.layout().stride(0), - underlying_testbed.tensor_B.layout().stride(0), - underlying_testbed.tensor_C.layout().stride(0), - underlying_testbed.tensor_D.layout().stride(0), - (int64_t)0 // Leading dimension of vector. This must be 0 - }; - } - }(); - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - status = gemm_op.initialize(arguments, workspace.get()); - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - // - // Run the GEMM - // - - status = gemm_op(); - - EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); - - cudaError_t cuda_error = cudaDeviceSynchronize(); - EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); - - // - // Verify - // - - bool passed = this->verify(problem_size, alpha, beta); - - if (!passed) { - std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; - } - - return passed; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Gemm, - typename GemmTestbed, - template class ActivationFunctor = cutlass::epilogue::thread::Identity -> -bool TestAllGemmWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { - - int const kMinimumOperandElementSize = - std::min( - int(cutlass::sizeof_bits::value), - int(cutlass::sizeof_bits::value)); - - int constexpr kAlignmentM = [&]() { - if constexpr (std::is_same_v>) { - // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) - // because of the reordering of operand E - return std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), - kMinimumOperandElementSize); - } - else { - return 128 / kMinimumOperandElementSize; - } - }(); - - int const kAlignmentN = 128 / kMinimumOperandElementSize; - - int M_problems[] = {kAlignmentM, 128 + 32}; - int N_problems[] = {kAlignmentN, 512 - 2 * kAlignmentN}; - int K_problems[] = {Gemm::ThreadblockShape::kK * 2}; - double alpha_problems[] = {1.}; - double beta_problems[] = {0.}; - int split_k_slices[] = { - 1, 2 - }; - - bool passed = true; - - for (int M : M_problems) { - for (int N : N_problems) { - for (int K : K_problems) { - for (int split_k : split_k_slices) { - if (cutlass::sizeof_bits_v <= 8 && split_k > 1) { - // Don't test split-K with FP8 output. The kernel being tested will writie partial accumulations - // for different splits to global memory in FP8, while the reference kernel will not. This leads - // to mismatches that are difficult to capture without a permissive relative equality check threshold. - continue; - } - - for (double alpha : alpha_problems) { - for (double beta : beta_problems) { - TestbedWithAmax testbed(scaleA, scaleB, scaleC); - - using ElementAccumulator = typename Gemm::ElementAccumulator; - - passed = testbed.run( - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, - split_k, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); - - EXPECT_TRUE(passed) - << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta << ", split_k:" << split_k; - - if (!passed) { - - return passed; - } - } - } - } - } - } - } - - return passed; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h deleted file mode 100644 index 8e939f9710403a5f5c3fd8c61e34c4e8021ff423..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h +++ /dev/null @@ -1,358 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/core_io.h" -#include "cutlass/numeric_types.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/gemm.h" - -#include "cutlass/gemm/kernel/default_gemv.h" -#include "cutlass/gemm/kernel/gemv_batched_strided.h" - -namespace test { -namespace gemm { -namespace kernel { - -template -void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, - ElementCD_ alpha = ElementCD_(1), - ElementCD_ beta = ElementCD_(0), - bool perf_test = false, - int perf_test_iter = 1) -{ - using ThreadBlockShape = ThreadBlockShape_; - using ThreadShape = ThreadShape_; - using ElementA = ElementAB_; - using LayoutA = LayoutA_; - using ElementB = ElementAB_; - using LayoutB = LayoutB_; - using ElementAccumulator = ElementCD_; - using ElementCD = ElementCD_; - using LayoutCD = LayoutCD_; - - using GemvKernel = cutlass::gemm::kernel::DefaultGemv; - - using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; - using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; - - if (DEBUG) - { - problem_size = cutlass::gemm::BatchedGemmCoord( - problem_size.m(), problem_size.n(), problem_size.k(), 1); - } - - // Create host tensors that will be the backing store for the batches - // Note that no device memory is initially allocated - cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); - cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); - cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); - cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); - - // Reserve memory for the batch of tensors - matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); - matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); - matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); - matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); - - // Fill eatch tensor batch - const int seed = 9876; - for (int b = 0; b < problem_size.batch(); b++) - { - if(DEBUG) - { - cutlass::reference::host::BlockFillSequential( - matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); - cutlass::reference::host::BlockFillSequential( - matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); - } - else - { - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(b*matrix_A.capacity()), - seed + 1660, - 8, - -8, - 0 - ); - - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(b*matrix_B.capacity()), - seed + 1880, - 8, - -8, - 0 - ); - } - - cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); - cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); - } - - matrix_A.sync_device(); - matrix_B.sync_device(); - matrix_C_computed.sync_device(); - - ThreadBlockSwizzle swizzle; - - cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, - ThreadBlockShape::kN, - problem_size.k(), // no split-k - DEBUG ? 1 : THREAD_B }; - - cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); - - #if 0 - printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); - printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); - #endif - - // No split-k - EXPECT_EQ(tiled_size.k(), problem_size.k()); - - dim3 grid = swizzle.get_grid_shape(tiled_shape); - dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); - - // Some sanity checks - EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); - EXPECT_TRUE( block.x <= 1024 ); - EXPECT_TRUE( block.y <= 1024 ); - EXPECT_TRUE( block.z <= 64 ); - - #if 0 - printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); - printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); - #endif - - cudaError_t result; - cudaEvent_t start_event, end_event; - - for (int iter = 0; iter < (perf_test ? (perf_test_iter+1) : 1); ++iter) - { - if (perf_test && iter == 1) - { - result = cudaEventCreate(&start_event); - EXPECT_EQ(result, cudaSuccess); - - result = cudaEventCreate(&end_event); - EXPECT_EQ(result, cudaSuccess); - - result = cudaEventRecord(start_event); - EXPECT_EQ(result, cudaSuccess); - } - - if (beta == ElementCD(0)) - { - if (alpha == ElementCD(1)) - { - cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( - problem_size, - matrix_A.device_ref(), - matrix_A.capacity(), - matrix_B.device_ref(), - matrix_B.capacity(), - matrix_C_computed.device_ref(), - matrix_C_computed.capacity() - ); - } - else - { - cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( - problem_size, - alpha, - matrix_A.device_ref(), - matrix_A.capacity(), - matrix_B.device_ref(), - matrix_B.capacity(), - matrix_C_computed.device_ref(), - matrix_C_computed.capacity() - ); - } - } - else - { - cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( - problem_size, - alpha, - beta, - matrix_A.device_ref(), - matrix_A.capacity(), - matrix_B.device_ref(), - matrix_B.capacity(), - matrix_C_computed.device_ref(), - matrix_C_computed.capacity(), - matrix_C_computed.device_ref(), - matrix_C_computed.capacity() - ); - } - - if (iter == 0) - { - result = cudaGetLastError(); - EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); - } - } - - if (perf_test) - { - result = cudaEventRecord(end_event); - EXPECT_EQ(result, cudaSuccess); - } - - result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); - - if (perf_test) - { - float ms; - result = cudaEventElapsedTime(&ms, start_event, end_event); - EXPECT_EQ(result, cudaSuccess); - - double flops = (double(problem_size.m()) * - double(problem_size.n()) * - double(problem_size.k()) * - double(problem_size.batch()) * 2); // 2 for MAC - - double read_bytes = double(problem_size.batch()) * (sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + - sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); - - double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); - - double avg_runtime = double(ms) / perf_test_iter; - double gflops_per_sec = flops / 1.0e6 / avg_runtime; - double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; - double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; - - std::cout << "\n\nProblem size: " - << problem_size.m() - << " x " << problem_size.n() - << " x " << problem_size.k() - << " x " << problem_size.batch() - << std::endl; - - std::cout << " GFLOPs: " << gflops_per_sec << std::endl; - std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; - std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; - } - else - { - matrix_C_computed.sync_host(); - - // Compute the batched gemms - for (int b = 0; b < problem_size.batch(); b++) - { - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - problem_size.mnk(), alpha, - matrix_A.host_ref(b * matrix_A.capacity()), - matrix_B.host_ref(b * matrix_B.capacity()), beta, - matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed.host_view(b * matrix_C_computed.capacity()), - matrix_C_reference.host_view(b * matrix_C_reference.capacity())); - - EXPECT_TRUE(passed) - //<< "A:\n" << matrix_A.host_view() << "\n" - //<< "B:\n" << matrix_B.host_view() << "\n" - << "Batch: " << b << "\n" - << "Reference:\n" - << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) - << "\n" - << "Computed:\n" - << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) - << "\n"; - } - } -} - -template -void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, - ElementCD_ alpha = ElementCD_(1), - ElementCD_ beta = ElementCD_(0), - int iter = 50) -{ - batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); -} - -} // namespace threadblock -} // namespace kernel -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h deleted file mode 100644 index 6e3d6ab079d44345f2f55f4126ba3efc1eba47cb..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h +++ /dev/null @@ -1,232 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level GEMM -*/ - -#pragma once - -#include "cutlass/gemm/thread/mma.h" -#include "cutlass/layout/vector.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/gemm.h" - -namespace test { -namespace gemm { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Thread-level matrix multiply-accumulate -template -void kernel( - typename Mma::ElementC *D, - typename Mma::ElementA const *A, - typename Mma::ElementB const *B, - typename Mma::ElementC const *C) { - - auto ptr_D = reinterpret_cast *>(D); - auto ptr_A = reinterpret_cast const *>(A); - auto ptr_B = reinterpret_cast const *>(B); - auto ptr_C = reinterpret_cast const *>(C); - - Mma mma; - - auto a = *ptr_A; - auto b = *ptr_B; - auto c = *ptr_C; - - using Btype = typename Mma::ElementB; - cutlass::Array d; - - mma(d, a, b, c); - - *ptr_D = d; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC -> -struct Testbed { - - /// Thread-level matrix multiply-accumulate operator - using Mma = cutlass::gemm::thread::Mma< - Shape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC - >; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed() { - - tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); - tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - /// Runs the test - bool run() { - - // - // initialize device memory - // - - cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( - 0, // seed - 10, // max - 0, // min - 0); // bits after decimal - - cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( - tensor_A.host_view(), - tfill_rand_func); - - for (auto i=0; i< Shape::kM; i++) - for (auto j=0; j< Shape::kK; j++) - tfill_rand(cutlass::make_Coord(i,j)); - - cutlass::reference::host::BlockFillSequential( - tensor_B.host_data(), - tensor_B.capacity(), - ElementB(1), - ElementB(2) - ); - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - - // Host side call - kernel( - tensor_D_computed.host_data(), - tensor_A.host_data(), - tensor_B.host_data(), - tensor_C.host_data()); - - // - // Reference implementation - // - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - {Shape::kM, Shape::kN, Shape::kK}, - ElementC(1), - tensor_A.host_ref(), - tensor_B.host_ref(), - ElementC(0), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed) - << "A:\n" << tensor_A.host_view() << "\n\n" - << "B:\n" << tensor_B.host_view() << "\n\n" - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - - - return passed; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h deleted file mode 100644 index 8d34d7992b57cefa0eaf7300a5e1fb49f41a93e2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h +++ /dev/null @@ -1,236 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level GEMM -*/ - -#pragma once - -#include "cutlass/gemm/thread/mma.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/gemm.h" - -namespace test { -namespace gemm { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Thread-level matrix multiply-accumulate -template -__global__ void kernel( - typename Mma::ElementC *D, - typename Mma::ElementA const *A, - typename Mma::ElementB const *B, - typename Mma::ElementC const *C) { - - auto ptr_D = reinterpret_cast *>(D); - auto ptr_A = reinterpret_cast const *>(A); - auto ptr_B = reinterpret_cast const *>(B); - auto ptr_C = reinterpret_cast const *>(C); - - Mma mma; - - auto a = *ptr_A; - auto b = *ptr_B; - auto c = *ptr_C; - - cutlass::Array d; - - mma(d, a, b, c); - - *ptr_D = d; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC -> -struct Testbed { - - /// Thread-level matrix multiply-accumulate operator - using Mma = cutlass::gemm::thread::Mma< - Shape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC - >; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed() { - - tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); - tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - /// Runs the test - bool run() { - - // - // initialize device memory - // - - cutlass::reference::host::BlockFillSequential( - tensor_A.host_data(), - tensor_A.capacity() - ); - - cutlass::reference::host::BlockFillSequential( - tensor_B.host_data(), - tensor_B.capacity(), - ElementB(1), - ElementB(2) - ); - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - - // launch kernel - kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( - tensor_D_computed.device_data(), - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data()); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - - //tensor_D_reference.fill(tensor_C.host_view()); - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - {Shape::kM, Shape::kN, Shape::kK}, - ElementC(1), - tensor_A.host_ref(), - tensor_B.host_ref(), - ElementC(0), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed) - << "A:\n" << tensor_A.host_view() << "\n\n" - << "B:\n" << tensor_B.host_view() << "\n\n" - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h deleted file mode 100644 index 1f3bc8cf114d7eb2ac00bd19ae92c984558b7228..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h +++ /dev/null @@ -1,435 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit testbed for kernel-level GEMM -*/ - -#pragma once - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/core_io.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/host_reorder.h" -#include "cutlass/util/host_uncompress.h" - -namespace test { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template -__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, - typename Mma::IteratorA::Params params_A, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::Params params_B, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC *ptr_C, - typename Mma::LayoutC::Stride::Index ldc, - typename Mma::IteratorE::Params params_E, - typename Mma::IteratorE::TensorRef ref_E) { - // Shared storage needed by threadblock-scoped matrix multiply- - // Dynamic shared memory base pointer - extern __shared__ int GemmSharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - typename Mma::SharedStorage *shared_storage = - reinterpret_cast(GemmSharedStorageBase); - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), - 0}; - - cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k() / Mma::kSparse}; - - cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), - tb_tile_offset.n() * Mma::Shape::kN}; - - cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k() / Mma::kSparse}; - - // Compute position within threadblock - int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params_A, ref_A.data(), - {problem_size.m(), problem_size.k() / Mma::kSparse}, - tb_thread_id, tb_offset_A); - - typename Mma::IteratorB iterator_B(params_B, ref_B.data(), - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - typename Mma::IteratorE iterator_E( - params_E, ref_E.data(), - {problem_size.m(), - problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, - tb_thread_id, tb_offset_E); - - int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); - - // Construct thread-scoped matrix multiply - Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); - - typename Mma::FragmentC accum; - - accum.clear(); - - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); - - // Output results - typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); - - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_id % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_id / Mma::WarpCount::kM)}); - - iterator_C.store(accum); -} - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Threadblock-level matrix multiply-accumulate - typename MmaCore_> -struct SparseTestbed { - /// Threadblock-level GEMM implementation - using MmaCore = MmaCore_; - using ThreadblockShape = typename MmaCore::Shape; - using WarpShape = typename MmaCore::WarpShape; - using InstructionShape = typename MmaCore::InstructionShape; - using ElementA = typename MmaCore::ElementA; - using LayoutA = typename MmaCore::LayoutA; - using ElementB = typename MmaCore::ElementB; - using LayoutB = typename MmaCore::LayoutB; - using ElementC = typename MmaCore::ElementC; - using LayoutC = typename MmaCore::LayoutC; - using ElementE = typename MmaCore::ElementE; - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using ThreadMapE = typename MmaCore::IteratorThreadMapE; - using AccessTypeA = cutlass::Array; - using AccessTypeB = cutlass::Array; - using AccessTypeE = cutlass::Array; - static int const Stages = MmaCore::kStages; - static cutlass::arch::CacheOperation::Kind const CacheOpA = - MmaCore::kCacheOpA; - static cutlass::arch::CacheOperation::Kind const CacheOpB = - MmaCore::kCacheOpB; - static cutlass::arch::CacheOperation::Kind const CacheOpE = - MmaCore::kCacheOpE; - - static int const Sparse = MmaCore::kSparse; - static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; - static int const MaxID2 = MmaCore::kMaxID2; - - using LayoutE = cutlass::layout::RowMajor; - using ReorderedLayoutE = typename MmaCore::GmemLayoutE; - - static int const ElementsPerElementE = MmaCore::kElementsPerElementE; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define iterators over tiles from the E operand - using IteratorE = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; - - // Define the threadblock-scoped pipelined matrix multiply - using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, - LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, - typename MmaCore::MmaPolicy, Stages>; - - // - // Data members - // - - cutlass::HostTensor matrix_A; - cutlass::HostTensor matrix_A_uncompressed; - cutlass::HostTensor matrix_B; - cutlass::HostTensor matrix_C_computed; - cutlass::HostTensor matrix_C_reference; - cutlass::HostTensor matrix_E; - cutlass::HostTensor matrix_E_reordered; - - cutlass::gemm::GemmCoord problem_size; - float alpha, beta; - - // - // Methods - // - - /// Allocates workspace in device memory - SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) - : problem_size(m, n, k), alpha(alpha_), beta(beta_) { - matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); - matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); - matrix_B.reset(cutlass::make_Coord(k, n)); - matrix_C_computed.reset(cutlass::make_Coord(m, n)); - matrix_C_reference.reset(cutlass::make_Coord(m, n), false); - matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); - matrix_E_reordered.reset( - cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - return true; - } - - /// Runs the test - bool run( - dim3 grid, dim3 block, - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { - - // Waive the test - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), - matrix_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), - matrix_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); - - cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); - - if (init_E == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomSparseMeta( - matrix_E.host_view(), seed, MetaSizeInBits); - } else if (init_E == cutlass::Distribution::Identity) { - uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; - cutlass::reference::host::TensorFill(matrix_E.host_view(), - (ElementE)(content)); - } else { - return false; - } - - cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), - {problem_size.m(), problem_size.n(), - problem_size.k() / Sparse / ElementsPerElementE}); - - matrix_A.sync_device(); - matrix_B.sync_device(); - matrix_C_computed.sync_device(); - matrix_E_reordered.sync_device(); - - typename IteratorA::Params params_A(matrix_A.layout()); - typename IteratorB::Params params_B(matrix_B.layout()); - typename IteratorE::Params params_E(matrix_E_reordered.layout()); - - cudaError_t result; - - int smem_size = int(sizeof(typename Mma::SharedStorage)); - if (smem_size >= (48 << 10)) { - result = cudaFuncSetAttribute( - test::gemm::threadblock::kernel_multistage_mma_sparse, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) { - return true; - } - - result = cudaFuncSetAttribute( - test::gemm::threadblock::kernel_multistage_mma_sparse, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - return true; - } - } - - test::gemm::threadblock::kernel_multistage_mma_sparse - <<>>( - problem_size, params_A, matrix_A.device_ref(), params_B, - matrix_B.device_ref(), matrix_C_computed.device_data(), - matrix_C_computed.layout().stride(0), params_E, - matrix_E_reordered.device_ref()); - - // - // Check error code - // - - result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result); - - matrix_C_computed.sync_host(); - - cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), - matrix_E.host_ref(), problem_size.m(), - problem_size.k()); - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm(problem_size, ElementC(alpha), - matrix_A_uncompressed.host_view(), matrix_B.host_view(), - ElementC(beta), matrix_C_reference.host_view()); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed.host_view(), matrix_C_reference.host_view()); - - EXPECT_TRUE(passed); - - if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - - std::cout - << __FILE__ << ":" << __LINE__ << " " - << "A:\n" << matrix_A.host_view() << "\n" - << "B:\n" << matrix_B.host_view() << "\n" - << "E:\n" << matrix_E.host_view() << "\n" - << "Reference:\n" - << matrix_C_reference.host_view() << "\n" - << "Computed:\n" - << matrix_C_computed.host_view() << "\n"; - } - - EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); - - return passed; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h deleted file mode 100644 index 5caaf38ace92758bbc86970d8d4ff339d87348ab..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h +++ /dev/null @@ -1,372 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit testbed for kernel-level GEMM -*/ - -#pragma once - -#include "../../common/cutlass_unit_test.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/core_io.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/tensor_view_io.h" - -namespace test { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template -__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, - typename Mma::IteratorA::Params params_A, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::Params params_B, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC *ptr_C, - typename Mma::LayoutC::Stride::Index ldc) { - // Shared storage needed by threadblock-scoped matrix multiply-accumulate - - // Dynamic shared memory base pointer - extern __shared__ int GemmSharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - typename Mma::SharedStorage *shared_storage = - reinterpret_cast(GemmSharedStorageBase); - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), - 0}; - - cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k()}; - - cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), - tb_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params_A, ref_A.data(), - {problem_size.m(), problem_size.k()}, - tb_thread_id, tb_offset_A); - - typename Mma::IteratorB iterator_B(params_B, ref_B.data(), - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); - - // Construct thread-scoped matrix multiply - Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); - - typename Mma::FragmentC accum; - - accum.clear(); - - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - - // Output results - typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); - - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_id % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_id / Mma::WarpCount::kM)}); - - iterator_C.store(accum); -} - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Threadblock-level matrix multiply-accumulate - typename MmaCore_> -struct Testbed { - /// Threadblock-level GEMM implementation - using MmaCore = MmaCore_; - using ThreadblockShape = typename MmaCore::Shape; - using WarpShape = typename MmaCore::WarpShape; - using InstructionShape = typename MmaCore::InstructionShape; - using ElementA = typename MmaCore::ElementA; - using LayoutA = typename MmaCore::LayoutA; - using ElementB = typename MmaCore::ElementB; - using LayoutB = typename MmaCore::LayoutB; - using ElementC = typename MmaCore::ElementC; - using LayoutC = typename MmaCore::LayoutC; - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeA = cutlass::Array; - using AccessTypeB = cutlass::Array; - static int const Stages = MmaCore::kStages; - static cutlass::arch::CacheOperation::Kind const CacheOpA = - MmaCore::kCacheOpA; - static cutlass::arch::CacheOperation::Kind const CacheOpB = - MmaCore::kCacheOpB; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped pipelined matrix multiply - using Mma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, - LayoutC, typename MmaCore::MmaPolicy, Stages>; - - // - // Data members - // - - cutlass::HostTensor matrix_A; - cutlass::HostTensor matrix_B; - cutlass::HostTensor matrix_C_computed; - cutlass::HostTensor matrix_C_reference; - - cutlass::gemm::GemmCoord problem_size; - float alpha, beta; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) - : problem_size(m, n, k), alpha(alpha_), beta(beta_) { - matrix_A.reset(cutlass::make_Coord(m, k)); - matrix_B.reset(cutlass::make_Coord(k, n)); - matrix_C_computed.reset(cutlass::make_Coord(m, n)); - matrix_C_reference.reset(cutlass::make_Coord(m, n), false); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - // - // Determine SMEM requirements and waive if not satisfied - // - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - return true; - } - - /// Runs the test - bool run( - dim3 grid, dim3 block, - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), - matrix_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), - matrix_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); - - cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); - - matrix_A.sync_device(); - matrix_B.sync_device(); - matrix_C_computed.sync_device(); - - typename IteratorA::Params params_A(matrix_A.layout()); - typename IteratorB::Params params_B(matrix_B.layout()); - - cudaError_t result; - - int smem_size = int(sizeof(typename Mma::SharedStorage)); - if (smem_size >= (48 << 10)) { - result = cudaFuncSetAttribute( - test::gemm::threadblock::kernel_multistage_mma, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - result = cudaFuncSetAttribute( - test::gemm::threadblock::kernel_multistage_mma, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - if (result != cudaSuccess) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - } - - test::gemm::threadblock::kernel_multistage_mma - <<>>( - problem_size, params_A, matrix_A.device_ref(), params_B, - matrix_B.device_ref(), matrix_C_computed.device_data(), - matrix_C_computed.layout().stride(0)); - - // - // Check error code - // - - result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result); - - matrix_C_computed.sync_host(); - - cutlass::reference::host::Gemm reference_gemm; - - reference_gemm( - problem_size, ElementC(alpha), matrix_A.host_view(), - matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed.host_view(), matrix_C_reference.host_view()); - - EXPECT_TRUE(passed); - - if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cout - << __FILE__ << ":" << __LINE__ << " " - << "A:\n" << matrix_A.host_view() << "\n" - << "B:\n" << matrix_B.host_view() << "\n" - << "Reference:\n" - << matrix_C_reference.host_view() << "\n" - << "Computed:\n" - << matrix_C_computed.host_view() << "\n"; - } - - EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); - - return passed; - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h deleted file mode 100644 index 4e617d6327594570b1a88a5b28f2ec4d0467b534..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h +++ /dev/null @@ -1,387 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Unit testbed for kernel-level GEMM -*/ - -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/vector.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -namespace test { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, - typename Mma::IteratorA::Params params_A, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::Params params_B, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC **ptr_C, - typename Mma::LayoutC::Stride::Index ldc) { - // Shared storage needed by threadblock-scoped matrix multiply-accumulate - - // Dynamic shared memory base pointer - extern __shared__ int GemmSharedStorageBase[]; - - // Declare pointer to dynamic shared memory. - typename Mma::SharedStorage *shared_storage = - reinterpret_cast(GemmSharedStorageBase); - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), - 0}; - - cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k()}; - - cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), - tb_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params_A, ref_A.data(), - {problem_size.m(), problem_size.k()}, - tb_thread_id, tb_offset_A); - - typename Mma::IteratorB iterator_B(params_B, ref_B.data(), - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); - int lane_id = threadIdx.x; - - int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); - - // Construct thread-scoped matrix multiply - Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); - - typename Mma::FragmentC accum; - - accum.clear(); - - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - - // Output results - typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); - - int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_idx_mn % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_idx_mn / Mma::WarpCount::kM)}); - - iterator_C.store(accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Threadblock-level matrix multiply-accumulate - typename MmaCore_> -struct Testbed { - /// Threadblock-level GEMM implementation - using MmaCore = MmaCore_; - using ThreadblockShape = typename MmaCore::Shape; - using WarpShape = typename MmaCore::WarpShape; - using InstructionShape = typename MmaCore::InstructionShape; - using ElementA = typename MmaCore::ElementA; - using LayoutA = typename MmaCore::LayoutA; - using ElementB = typename MmaCore::ElementB; - using LayoutB = typename MmaCore::LayoutB; - using ElementC = typename MmaCore::ElementC; - using LayoutC = typename MmaCore::LayoutC; - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeA = cutlass::Array; - using AccessTypeB = cutlass::Array; - static int const Stages = MmaCore::kStages; - static cutlass::arch::CacheOperation::Kind const CacheOpA = - MmaCore::kCacheOpA; - static cutlass::arch::CacheOperation::Kind const CacheOpB = - MmaCore::kCacheOpB; - - // Define iterators over tiles from the A operand - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - - // Define iterators over tiles from the B operand - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; - - // Define the threadblock-scoped pipelined matrix multiply - using Mma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, CacheOpA, - IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, - typename MmaCore::MmaPolicy, Stages>; - - static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; - - // - // Data members - // - - cutlass::HostTensor matrix_A; - cutlass::HostTensor matrix_B; - cutlass::HostTensor matrix_C_computed[kPartitionsK]; - cutlass::HostTensor matrix_C_reference; - cutlass::HostTensor matrix_C_pointers; - - cutlass::gemm::GemmCoord problem_size; - float alpha, beta; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) - : problem_size(m, n, k), alpha(alpha_), beta(beta_) { - matrix_A.reset(cutlass::make_Coord(m, k)); - matrix_B.reset(cutlass::make_Coord(k, n)); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); - - matrix_C_reference.reset(cutlass::make_Coord(m, n), false); - matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); - } - - /// Runs the test - bool run( - dim3 grid, dim3 block, - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), - matrix_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), - matrix_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); - } else { - return false; - } - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); - - cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); - - matrix_A.sync_device(); - matrix_B.sync_device(); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_computed[k].sync_device(); - - typename IteratorA::Params params_A(matrix_A.layout()); - typename IteratorB::Params params_B(matrix_B.layout()); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); - - matrix_C_pointers.sync_device(); - - cudaError_t result; - - int smem_size = int(sizeof(typename Mma::SharedStorage)); - if (smem_size >= (48 << 10)) { - result = cudaFuncSetAttribute( - test::gemm::threadblock::kernel_multistage_mma, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - EXPECT_EQ(result, cudaSuccess) - << " cudaFuncSetAttribute " - "cudaFuncAttributeMaxDynamicSharedMemorySize error: " - << cudaGetErrorString(result); - - result = cudaFuncSetAttribute( - test::gemm::threadblock::kernel_multistage_mma, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); - - EXPECT_EQ(result, cudaSuccess) - << " cudaFuncSetAttribute " - "cudaFuncAttributePreferredSharedMemoryCarveout error: " - << cudaGetErrorString(result); - } - - test::gemm::threadblock::kernel_multistage_mma<<>>( - problem_size, params_A, matrix_A.device_ref(), params_B, - matrix_B.device_ref(), matrix_C_pointers.device_data(), - matrix_C_computed[0].layout().stride(0)); - - // - // Check error code - // - - result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_computed[k].sync_host(); - - // TODO: this is temporary. it will be removed after slicing can de - // reduction - // - // Reduce matrix_C_computed - // - CUTLASS_PRAGMA_UNROLL - for(int k = 1; k < kPartitionsK; k++) { - CUTLASS_PRAGMA_UNROLL - for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ - CUTLASS_PRAGMA_UNROLL - for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ - matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); - } - } - } - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - problem_size, ElementC(alpha), matrix_A.host_view(), - matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); - - EXPECT_TRUE(passed); - - if (!passed) { - std::ofstream output("mma_multistage_testbed_errors.txt"); - - output - << "A:\n" << matrix_A.host_view() << "\n" - << "B:\n" << matrix_B.host_view() << "\n" - << "Reference:\n" - << matrix_C_reference.host_view() << "\n" - << "Computed:\n" - << matrix_C_computed[0].host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h deleted file mode 100644 index 7eb62f9a39fe4472f77446efc591267001758c58..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h +++ /dev/null @@ -1,353 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit testbed for kernel-level GEMM -*/ - -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/vector.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -namespace test { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, - typename Mma::IteratorA::Params params_A, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::Params params_B, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC *ptr_C, - typename Mma::LayoutC::Stride::Index ldc) { - // Shared storage needed by threadblock-scoped matrix multiply-accumulate - __shared__ typename Mma::SharedStorage shared_storage; - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), - 0}; - - cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k()}; - - cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), - tb_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params_A, ref_A.data(), - {problem_size.m(), problem_size.k()}, - tb_thread_id, tb_offset_A); - - typename Mma::IteratorB iterator_B(params_B, ref_B.data(), - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - int warp_id = threadIdx.y; - int lane_id = threadIdx.x; - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); - - typename Mma::FragmentC accum; - - accum.clear(); - - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - - // Output results - typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); - - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_id % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_id / Mma::WarpCount::kM)}); - - iterator_C.store(accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Threadblock-level matrix multiply-accumulate - typename MmaCore_, - /// Number of stages - int Stages = 2> -struct Testbed { - /// Threadblock-level GEMM implementation - using MmaCore = MmaCore_; - using ThreadblockShape = typename MmaCore::Shape; - using WarpShape = typename MmaCore::WarpShape; - using InstructionShape = typename MmaCore::InstructionShape; - using ElementA = typename MmaCore::ElementA; - using LayoutA = typename MmaCore::LayoutA; - using ElementB = typename MmaCore::ElementB; - using LayoutB = typename MmaCore::LayoutB; - using ElementC = typename MmaCore::ElementC; - using LayoutC = typename MmaCore::LayoutC; - static const int kStages = Stages; - - // Define iterators over tiles from the A operand - static const bool use_idp4a = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value; - - static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; - static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; - - using IteratorA = typename cutlass::platform::conditional< use_idp4a, - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , - - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> - >::type; - - // Define iterators over tiles from the B operand - using IteratorB = typename cutlass::platform::conditional< use_idp4a, - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , - - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> - >::type; - - // Define MmaPipeline Single Stage - using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, - typename MmaCore::MmaPolicy>; - - // Define MmaPipeline Two Stages - using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, - typename MmaCore::MmaPolicy>; - - // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. Two stages) - using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; - // - // Data members - // - - cutlass::HostTensor matrix_A; - cutlass::HostTensor matrix_B; - cutlass::HostTensor matrix_C_computed; - cutlass::HostTensor matrix_C_reference; - - cutlass::gemm::GemmCoord problem_size; - float alpha, beta; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed(int m, int n, int k, float alpha_, float beta_) - : problem_size(m, n, k), alpha(alpha_), beta(beta_) { - matrix_A.reset(cutlass::make_Coord(m, k)); - matrix_B.reset(cutlass::make_Coord(k, n)); - matrix_C_computed.reset(cutlass::make_Coord(m, n)); - matrix_C_reference.reset(cutlass::make_Coord(m, n), false); - } - - bool sufficient() { - return true; - } - - /// Runs the test - bool run( - dim3 grid, dim3 block, - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - // Waive test if insufficient CUDA device - if (!sufficient()) { - if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { - std::cerr << "Test waived due to insufficient CUDA device." << std::endl; - } - return true; - } - - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), - matrix_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), - matrix_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); - - cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); - - matrix_A.sync_device(); - matrix_B.sync_device(); - matrix_C_computed.sync_device(); - - typename IteratorA::Params params_A(matrix_A.layout()); - typename IteratorB::Params params_B(matrix_B.layout()); - - test::gemm::threadblock::kernel_mma<<>>( - problem_size, params_A, matrix_A.device_ref(), params_B, - matrix_B.device_ref(), matrix_C_computed.device_data(), - matrix_C_computed.layout().stride(0)); - - // - // Check error code - // - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); - - matrix_C_computed.sync_host(); - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - problem_size, ElementC(alpha), matrix_A.host_view(), - matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed.host_view(), matrix_C_reference.host_view()); - - EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); - - if (!passed) { - std::ofstream output("mma_pipelined_testbed_errors.txt"); - - output - << "A:\n" << matrix_A.host_view() << "\n" - << "B:\n" << matrix_B.host_view() << "\n" - << "Reference:\n" - << matrix_C_reference.host_view() << "\n" - << "Computed:\n" - << matrix_C_computed.host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h deleted file mode 100644 index 36e55b2542b2258542336a052cdd14bf4b85f78d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h +++ /dev/null @@ -1,370 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Unit testbed for kernel-level GEMM -*/ - -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/vector.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -namespace test { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, - typename Mma::IteratorA::Params params_A, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::Params params_B, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC **ptr_C, - typename Mma::LayoutC::Stride::Index ldc) { - // Shared storage needed by threadblock-scoped matrix multiply-accumulate - __shared__ typename Mma::SharedStorage shared_storage; - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), - 0}; - - cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k()}; - - cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), - tb_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params_A, ref_A.data(), - {problem_size.m(), problem_size.k()}, - tb_thread_id, tb_offset_A); - - typename Mma::IteratorB iterator_B(params_B, ref_B.data(), - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - int warp_id = threadIdx.y; - int lane_id = threadIdx.x; - - int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); - - typename Mma::FragmentC accum; - - accum.clear(); - - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - - // Output results - typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); - - - int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_idx_mn % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_idx_mn / Mma::WarpCount::kM)}); - - iterator_C.store(accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Threadblock-level matrix multiply-accumulate - typename MmaCore_> -struct Testbed { - /// Threadblock-level GEMM implementation - using MmaCore = MmaCore_; - using ThreadblockShape = typename MmaCore::Shape; - using WarpShape = typename MmaCore::WarpShape; - using InstructionShape = typename MmaCore::InstructionShape; - using ElementA = typename MmaCore::ElementA; - using LayoutA = typename MmaCore::LayoutA; - using ElementB = typename MmaCore::ElementB; - using LayoutB = typename MmaCore::LayoutB; - using ElementC = typename MmaCore::ElementC; - using LayoutC = typename MmaCore::LayoutC; - - // Define iterators over tiles from the A operand - static const bool use_idp4a = cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value; - - static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; - static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; - - using IteratorA = typename cutlass::platform::conditional< use_idp4a, - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , - - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> - >::type; - - // Define iterators over tiles from the B operand - using IteratorB = typename cutlass::platform::conditional< use_idp4a, - cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , - - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> - >::type; - - // Define the threadblock-scoped pipelined matrix multiply - using Mma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, - IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, - typename MmaCore::MmaPolicy>; - - static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; - - // - // Data members - // - - cutlass::HostTensor matrix_A; - cutlass::HostTensor matrix_B; - cutlass::HostTensor matrix_C_computed[kPartitionsK]; - cutlass::HostTensor matrix_C_reference; - cutlass::HostTensor matrix_C_pointers; - - cutlass::gemm::GemmCoord problem_size; - float alpha, beta; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed(int m, int n, int k, float alpha_, float beta_) - : problem_size(m, n, k), alpha(alpha_), beta(beta_) { - matrix_A.reset(cutlass::make_Coord(m, k)); - matrix_B.reset(cutlass::make_Coord(k, n)); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); - - matrix_C_reference.reset(cutlass::make_Coord(m, n), false); - matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); - } - - /// Runs the test - bool run( - dim3 grid, dim3 block, - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), - matrix_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), - matrix_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); - } else { - return false; - } - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); - - cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); - - matrix_A.sync_device(); - matrix_B.sync_device(); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_computed[k].sync_device(); - - typename IteratorA::Params params_A(matrix_A.layout()); - typename IteratorB::Params params_B(matrix_B.layout()); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); - - matrix_C_pointers.sync_device(); - - test::gemm::threadblock::kernel_mma<<>>( - problem_size, params_A, matrix_A.device_ref(), params_B, - matrix_B.device_ref(), matrix_C_pointers.device_data(), - matrix_C_computed[0].layout().stride(0)); - - // - // Check error code - // - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result); - - CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < kPartitionsK; k++) - matrix_C_computed[k].sync_host(); - - // TODO: this is temporary. it will be removed after slicing can de - // reduction - // - // Reduce matrix_C_computed - // - CUTLASS_PRAGMA_UNROLL - for(int k = 1; k < kPartitionsK; k++) { - CUTLASS_PRAGMA_UNROLL - for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ - CUTLASS_PRAGMA_UNROLL - for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ - matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); - } - } - } - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - problem_size, ElementC(alpha), matrix_A.host_view(), - matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); - - EXPECT_TRUE(passed); - - if (!passed) { - std::ofstream output("mma_pipelined_testbed_errors.txt"); - - output - << "A:\n" << matrix_A.host_view() << "\n" - << "B:\n" << matrix_B.host_view() << "\n" - << "Reference:\n" - << matrix_C_reference.host_view() << "\n" - << "Computed:\n" - << matrix_C_computed[0].host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h deleted file mode 100644 index e5fdc07769726353b33c1a5da65dedfadb4ce1e7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h +++ /dev/null @@ -1,350 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit testbed for kernel-level GEMM -*/ - -#pragma once - -#include - -#include "../../common/cutlass_unit_test.h" - -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/aligned_buffer.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/vector.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor_planar_complex.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm_planar_complex.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void kernel_mma_planar_complex( - cutlass::gemm::GemmCoord problem_size, - typename Mma::IteratorA::Params params_A, - typename Mma::IteratorA::Element *ptr_A, - int64_t imaginary_stride_A, - typename Mma::IteratorB::Params params_B, - typename Mma::IteratorB::Element *ptr_B, - int64_t imaginary_stride_B, - typename Mma::ElementC *ptr_C, - typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { - - // Shared storage needed by threadblock-scoped matrix multiply-accumulate - __shared__ typename Mma::SharedStorage shared_storage; - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), - 0}; - - cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, - tb_tile_offset.k()}; - - cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), - tb_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; - - // Construct iterators to A operand - typename Mma::IteratorA iterator_A_real(params_A, ptr_A, - {problem_size.m(), problem_size.k()}, - tb_thread_id, tb_offset_A); - - typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, - {problem_size.m(), problem_size.k()}, - tb_thread_id, tb_offset_A); - - // Construct iterators to B operand - typename Mma::IteratorB iterator_B_real(params_B, ptr_B, - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, - {problem_size.k(), problem_size.n()}, - tb_thread_id, tb_offset_B); - - int warp_id = threadIdx.y; - int lane_id = threadIdx.x; - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); - - typename Mma::FragmentC accum; - - accum.clear(); - - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); - - // Output results - typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); - - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_id % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_id / Mma::WarpCount::kM)}); - - iterator_C.store(accum.real); - - iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Threadblock-level matrix multiply-accumulate - typename Mma_> -struct TestbedPlanarComplex { - - using Mma = Mma_; - using ThreadblockShape = typename Mma::Shape; - using IteratorA = typename Mma::IteratorA; - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using IteratorB = typename Mma::IteratorB; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Mma::ElementC; - using ElementAccumulator = typename Mma::ElementC; - using LayoutC = typename Mma::LayoutC; - using ThreadMapA = typename Mma::IteratorA::ThreadMap; - using ThreadMapB = typename Mma::IteratorB::ThreadMap; - using AccessTypeA = cutlass::Array; - using AccessTypeB = cutlass::Array; - static int const Stages = Mma::kStages; - static cutlass::arch::CacheOperation::Kind const CacheOpA = - Mma::kCacheOpA; - static cutlass::arch::CacheOperation::Kind const CacheOpB = - Mma::kCacheOpB; - - // - // Data members - // - - cutlass::HostTensorPlanarComplex matrix_A; - cutlass::HostTensorPlanarComplex matrix_B; - cutlass::HostTensorPlanarComplex matrix_C_computed; - cutlass::HostTensorPlanarComplex matrix_C_reference; - - cutlass::gemm::GemmCoord problem_size; - - // - // Methods - // - - /// Allocates workspace in device memory - TestbedPlanarComplex(int m, int n, int k) - : problem_size(m, n, k) { - - matrix_A.reset(cutlass::make_Coord(m, k)); - matrix_B.reset(cutlass::make_Coord(k, n)); - matrix_C_computed.reset(cutlass::make_Coord(m, n)); - matrix_C_reference.reset(cutlass::make_Coord(m, n), false); - } - - /// Runs the test - bool run( - dim3 grid, dim3 block, - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_A.host_view(), seed, scope_max, scope_min, 0); - - } else if (init_A == cutlass::Distribution::Sequential) { - - for (int i = 0; i < matrix_A.capacity() * 2; ++i) { - matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); - } - /* - cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), - matrix_A.capacity() * 2); - */ - } else if (init_A == cutlass::Distribution::Identity) { - //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - - - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); - - - } else if (init_B == cutlass::Distribution::Sequential) { - - cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), - matrix_B.capacity() * 2); - - for (int i = 0; i < matrix_B.capacity() * 2; ++i) { - matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); - } - - - } else if (init_B == cutlass::Distribution::Identity) { - - //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); - - } else { - return false; - } - - matrix_A.sync_device(); - matrix_B.sync_device(); - matrix_C_computed.sync_device(); - - typename IteratorA::Params params_A(matrix_A.layout()); - typename IteratorB::Params params_B(matrix_B.layout()); - - test::gemm::threadblock::kernel_mma_planar_complex<<>>( - problem_size, - params_A, - matrix_A.device_data(), - matrix_A.imaginary_stride(), - params_B, - matrix_B.device_data(), - matrix_B.imaginary_stride(), - matrix_C_computed.device_data(), - matrix_C_computed.layout().stride(0), - matrix_C_computed.imaginary_stride() - ); - - - // - // Check error code - // - - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) - << " kernel error: " << cudaGetErrorString(result); - - matrix_C_computed.sync_host(); - - cutlass::reference::host::GemmPlanarComplex< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator - >( - problem_size, - cutlass::complex(ElementAccumulator(1)), - matrix_A.host_ref(), - Mma::kTransformA, - matrix_B.host_ref(), - Mma::kTransformB, - cutlass::complex(ElementAccumulator(0)), - matrix_C_reference.host_ref(), - matrix_C_reference.host_ref() - ); - - bool passed = cutlass::reference::host::TensorEquals( - matrix_C_computed.host_view(), - matrix_C_reference.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - std::ofstream output("mma_pipelined_testbed_errors.txt"); - - output - << "A:\n" << matrix_A.host_view() << "\n" - << "B:\n" << matrix_B.host_view() << "\n" - << "Reference:\n" - << matrix_C_reference.host_view() << "\n" - << "Computed:\n" - << matrix_C_computed.host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h deleted file mode 100644 index 921d1abdc40c2040104815cfffb8b2ea32384136..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h +++ /dev/null @@ -1,1543 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level GEMM -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/numeric_types.h" -#include "cutlass/subbyte_reference.h" -#include "cutlass/platform/platform.h" -#include "cutlass/arch/arch.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/distribution.h" -#include "cutlass/util/reference/host/gemm.h" -#include "cutlass/util/reference/host/gemm_complex.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/host_reorder.h" -#include "cutlass/util/host_uncompress.h" - -namespace test { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Test kernel -template -__global__ void kernel( - typename Mma::ElementC *output_C, - typename Mma::ElementA const *input_A, - typename Mma::ElementB const *input_B, - typename Mma::ElementC const *input_C, - int iterations = 1) { - - // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. - __shared__ cutlass::AlignedBuffer< - typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; - - __shared__ cutlass::AlignedBuffer< - typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; - - if (threadIdx.x == 0) { - typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_A.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_A, i) = - cutlass::ReferenceFactory::type>::get(input_A, i); - } - - typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_B.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_B, i) = - cutlass::ReferenceFactory::type>::get(input_B, i); - } - } - - __syncthreads(); - - // - // Construct warp-level matrix product - // - - using FragmentA = typename Mma::FragmentA; - using FragmentB = typename Mma::FragmentB; - using FragmentC = typename Mma::FragmentC; - - typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); - typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); - typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); - - typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); - - typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); - - FragmentA frag_A; - FragmentB frag_B; - - FragmentC accum; - - Mma mma; - - accum.clear(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < ThreadblockShape::kK; - k += Mma::Policy::MmaShape::kK) { - iter_A.load(frag_A); - iter_B.load(frag_B); - - ++iter_A; - ++iter_B; - - mma(accum, frag_A, frag_B, accum); - } - } - - typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); - - iter_C.store(accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Warp-level matrix multiply-accumulate - typename Mma_, - /// Size of threadblock-scoped shape used to store SMEM - typename ThreadblockShape_, - /// The inner product operation performed by GEMM - typename Operator_ = cutlass::arch::OpMultiplyAdd -> -struct Testbed { - - /// Thread-level matrix multiply-accumulate operator - using Mma = Mma_; - using ThreadblockShape = ThreadblockShape_; - using Operator = Operator_; - - using Shape = typename Mma::Shape; - using ElementA = typename Mma::ElementA; - using LayoutA = typename Mma::LayoutA; - using ElementB = typename Mma::ElementB; - using LayoutB = typename Mma::LayoutB; - using ElementC = typename Mma::ElementC; - using LayoutC = typename Mma::LayoutC; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed() { - - tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); - tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.major == 9) { - // NVIDIA Hopper drops support for several data types - if ( - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8) { - - return false; - } - } - - return true; - } - - - /// Runs the test - bool run( - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - - cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), - tensor_A.capacity(), seed, scope_max, scope_min, 0); - - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), - tensor_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - - cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), - tensor_B.capacity(), seed, scope_max, scope_min, 0); - - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), - tensor_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - - // launch kernel - kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( - tensor_D_computed.device_data(), - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data()); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - {Shape::kM, Shape::kN, ThreadblockShape::kK}, - ElementC(1), - tensor_A.host_ref(), - tensor_B.host_ref(), - ElementC(0), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - - cutlass::TensorView tensor_A_physical( - tensor_A.host_data(), - tensor_A.stride()[0], - tensor_A.extent()); - - cutlass::TensorView tensor_B_physical( - tensor_B.host_data(), - tensor_B.stride()[0], - tensor_B.extent()); - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride()[0] - << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride()[0] - << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; - - std::cout - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Warp-level matrix multiply-accumulate - typename Mma_, - /// Size of threadblock-scoped shape used to store SMEM - typename ThreadblockShape_ -> -struct TestbedComplex { - - /// Thread-level matrix multiply-accumulate operator - using Mma = Mma_; - using ThreadblockShape = ThreadblockShape_; - - using Shape = typename Mma::Shape; - using ElementA = typename Mma::ElementA; - using LayoutA = typename Mma::LayoutA; - using ElementB = typename Mma::ElementB; - using LayoutB = typename Mma::LayoutB; - using ElementC = typename Mma::ElementC; - using LayoutC = typename Mma::LayoutC; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - TestbedComplex() { - - tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); - tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.major == 9) { - // NVIDIA Hopper drops support for several data types - if ( - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8) { - - return false; - } - } - - return true; - } - - /// Runs the test - bool run( - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), - seed, 8, -8, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), - tensor_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), - seed + 16, 8, -8, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), - tensor_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - - // launch kernel - kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( - tensor_D_computed.device_data(), - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data()); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - - cutlass::reference::host::GemmComplex( - {Shape::kM, Shape::kN, ThreadblockShape::kK}, - ElementC(1), - tensor_A.host_ref(), - Mma::kTransformA, - tensor_B.host_ref(), - Mma::kTransformB, - ElementC(0), - tensor_C.host_ref(), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - - cutlass::TensorView tensor_A_physical( - tensor_A.host_data(), - tensor_A.stride()[0], - tensor_A.extent()); - - cutlass::TensorView tensor_B_physical( - tensor_B.host_data(), - tensor_B.stride()[0], - tensor_B.extent()); - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; - - std::cout - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Test kernel -template -__global__ void kernel_transform( - typename Mma::ElementC *output_C, - typename Mma::ElementA const *input_A, - typename Mma::ElementB const *input_B, - typename Mma::ElementC const *input_C, - int iterations = 1) { - - // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. - __shared__ cutlass::AlignedBuffer< - typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; - - __shared__ cutlass::AlignedBuffer< - typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; - - if (threadIdx.x == 0) { - typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_A.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_A, i) = - cutlass::ReferenceFactory::type>::get(input_A, i); - } - - typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_B.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_B, i) = - cutlass::ReferenceFactory::type>::get(input_B, i); - } - } - - __syncthreads(); - - // - // Construct warp-level matrix product - // - - using FragmentA = typename Mma::FragmentA; - using FragmentB = typename Mma::FragmentB; - using FragmentC = typename Mma::FragmentC; - - using TransformedFragmentA = typename Mma::TransformedFragmentA; - using TransformedFragmentB = typename Mma::TransformedFragmentB; - - typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); - typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); - typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); - - typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); - - typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); - - FragmentA loaded_frag_A; - FragmentB loaded_frag_B; - TransformedFragmentA transformed_frag_A; - TransformedFragmentB transformed_frag_B; - - FragmentC accum; - - Mma mma; - - accum.clear(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < ThreadblockShape::kK; - k += Mma::Policy::MmaShape::kK) { - iter_A.load(loaded_frag_A); - iter_B.load(loaded_frag_B); - - ++iter_A; - ++iter_B; - - mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, - loaded_frag_B); - - mma(accum, transformed_frag_A, transformed_frag_B, accum); - } - } - - typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); - - iter_C.store(accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Warp-level matrix multiply-accumulate - typename Mma_, - /// Size of threadblock-scoped shape used to store SMEM - typename ThreadblockShape_, - /// The innter product operation performed by GEMM - typename Operator_ = cutlass::arch::OpMultiplyAdd -> -struct TransformTestbed { - - /// Thread-level matrix multiply-accumulate operator - using Mma = Mma_; - using ThreadblockShape = ThreadblockShape_; - using Operator = Operator_; - - using Shape = typename Mma::Shape; - using ElementA = typename Mma::ElementA; - using LayoutA = typename Mma::LayoutA; - using ElementB = typename Mma::ElementB; - using LayoutB = typename Mma::LayoutB; - using ElementC = typename Mma::ElementC; - using LayoutC = typename Mma::LayoutC; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - TransformTestbed() { - - tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); - tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.major == 9) { - // NVIDIA Hopper drops support for several data types - if ( - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8) { - - return false; - } - } - - return true; - } - - /// Runs the test - bool run( - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - tensor_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), - tensor_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), - tensor_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - - // launch kernel - kernel_transform<<>>( - tensor_D_computed.device_data(), tensor_A.device_data(), - tensor_B.device_data(), tensor_C.device_data()); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - {Shape::kM, Shape::kN, ThreadblockShape::kK}, - ElementC(1), - tensor_A.host_ref(), - tensor_B.host_ref(), - ElementC(0), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - - cutlass::TensorView tensor_A_physical( - tensor_A.host_data(), - tensor_A.stride()[0], - tensor_A.extent()); - - cutlass::TensorView tensor_B_physical( - tensor_B.host_data(), - tensor_B.stride()[0], - tensor_B.extent()); - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; - - std::cout - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Warp-level matrix multiply-accumulate - typename Mma_, - /// Size of threadblock-scoped shape used to store SMEM - typename ThreadblockShape_ -> -struct TransformedTestbedComplex { - - /// Thread-level matrix multiply-accumulate operator - using Mma = Mma_; - using ThreadblockShape = ThreadblockShape_; - - using Shape = typename Mma::Shape; - using ElementA = typename Mma::ElementA; - using LayoutA = typename Mma::LayoutA; - using ElementB = typename Mma::ElementB; - using LayoutB = typename Mma::LayoutB; - using ElementC = typename Mma::ElementC; - using LayoutC = typename Mma::LayoutC; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - TransformedTestbedComplex() { - - tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); - tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.major == 9) { - // NVIDIA Hopper drops support for several data types - if ( - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8) { - - return false; - } - } - - return true; - } - - /// Runs the test - bool run( - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { - - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), - seed, 8, -8, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), - tensor_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), - seed + 16, 8, -8, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), - tensor_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - - // launch kernel - kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( - tensor_D_computed.device_data(), - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data()); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - - cutlass::reference::host::GemmComplex( - {Shape::kM, Shape::kN, ThreadblockShape::kK}, - ElementC(1), - tensor_A.host_ref(), - Mma::kTransformA, - tensor_B.host_ref(), - Mma::kTransformB, - ElementC(0), - tensor_C.host_ref(), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - - cutlass::TensorView tensor_A_physical( - tensor_A.host_data(), - tensor_A.stride()[0], - tensor_A.extent()); - - cutlass::TensorView tensor_B_physical( - tensor_B.host_data(), - tensor_B.stride()[0], - tensor_B.extent()); - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout - << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; - - std::cout - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Test kernel -template -__global__ void sparse_kernel( - typename Mma::ElementC *output_C, - typename Mma::ElementA const *input_A, - typename Mma::ElementB const *input_B, - typename Mma::ElementC const *input_C, - typename Mma::ElementE const *input_E, - int iterations = 1) { - - // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. - __shared__ cutlass::AlignedBuffer - smem_buffer_A; - - __shared__ cutlass::AlignedBuffer< - typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; - - __shared__ cutlass::AlignedBuffer< - typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / - Mma::kSparse / Mma::kElementsPerElementE> - smem_buffer_E; - - __syncthreads(); - - if (threadIdx.x == 0) { - typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_A.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_A, i) = - cutlass::ReferenceFactory::type>::get(input_A, i); - } - - typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_B.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_B, i) = - cutlass::ReferenceFactory::type>::get(input_B, i); - } - - typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); - #pragma unroll 1 - for (size_t i = 0; i < smem_buffer_E.size(); ++i) { - cutlass::ReferenceFactory::get(smem_ptr_E, i) = - cutlass::ReferenceFactory::type>::get(input_E, i); - } - } - - __syncthreads(); - - // - // Construct warp-level matrix product - // - - using FragmentA = typename Mma::FragmentA; - using FragmentB = typename Mma::FragmentB; - using FragmentC = typename Mma::FragmentC; - using FragmentE = typename Mma::FragmentE; - - typename Mma::LayoutA layout_A = Mma::LayoutA::packed( - {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); - typename Mma::LayoutB layout_B = - Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); - typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); - typename Mma::LayoutE layout_E = - Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, - Mma::Shape::kK / Mma::kSparse / - Mma::kElementsPerElementE / Mma::kInterleaved}); - - typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); - - typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); - - typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); - - FragmentA frag_A; - FragmentB frag_B; - - FragmentC accum; - - FragmentE frag_E; - - Mma mma; - - accum.clear(); - - CUTLASS_PRAGMA_NO_UNROLL - for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled - - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < ThreadblockShape::kK; - k += Mma::Policy::MmaShape::kK) { - iter_A.load(frag_A); - iter_B.load(frag_B); - iter_E.load(frag_E); - - ++iter_A; - ++iter_B; - ++iter_E; - - mma(accum, frag_A, frag_B, accum, frag_E); - } - } - - typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); - - iter_C.store(accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product -template < - /// Warp-level matrix multiply-accumulate - typename Mma_, - /// Size of threadblock-scoped shape used to store SMEM - typename ThreadblockShape_, - /// The innter product operation performed by GEMM - typename Operator_ = cutlass::arch::OpMultiplyAdd -> -struct SparseTestbed { - - /// Thread-level matrix multiply-accumulate operator - using Mma = Mma_; - using ThreadblockShape = ThreadblockShape_; - using Operator = Operator_; - - using Shape = typename Mma::Shape; - using ElementA = typename Mma::ElementA; - using LayoutA = typename Mma::LayoutA; - using ElementB = typename Mma::ElementB; - using LayoutB = typename Mma::LayoutB; - using ElementC = typename Mma::ElementC; - using LayoutC = typename Mma::LayoutC; - - static int const Sparse = Mma::kSparse; - static int const MetaSizeInBits = Mma::kMetaSizeInBits; - static int const MaxID2 = Mma::kMaxID2; - static int const Interleaved = Mma::kInterleaved; - - using ElementE = typename Mma::ElementE; - - static int const ElementsPerElementE = Mma::kElementsPerElementE; - - using LayoutE = cutlass::layout::RowMajor; - using ReorderedLayoutE = - cutlass::layout::ColumnMajorInterleaved; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_A_uncompressed; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - cutlass::HostTensor tensor_E; - cutlass::HostTensor tensor_E_reordered; - - // - // Methods - // - - /// Allocates workspace in device memory - SparseTestbed() { - - tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, - ThreadblockShape::kK / Sparse)); - tensor_A_uncompressed.reset( - cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); - tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - tensor_E.reset(cutlass::make_Coord( - Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); - tensor_E_reordered.reset(cutlass::make_Coord( - Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); - } - - /// Returns true if the CUDA device is sufficient to execute the kernel. - bool sufficient() const { - - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.major == 9) { - // NVIDIA Hopper drops support for several data types - if ( - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8 || - cutlass::sizeof_bits::value < 8) { - - return false; - } - } - - return true; - } - - /// Runs the test - bool run( - cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { - - if (!sufficient()) { - return true; - } - - // - // initialize device memory - // - - if (init_A == cutlass::Distribution::Uniform) { - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - tensor_A.host_view(), seed, scope_max, scope_min, 0); - } else if (init_A == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), - tensor_A.capacity()); - } else if (init_A == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); - } else { - return false; - } - - if (init_B == cutlass::Distribution::Uniform) { - int scope_max = 8; - int scope_min = -8; - - if (cutlass::sizeof_bits::value == 4) { - scope_max = 2; - scope_min = -2; - } else if (cutlass::sizeof_bits::value == 1) { - scope_max = 2; - scope_min = 0; - } - - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); - } else if (init_B == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), - tensor_B.capacity()); - } else if (init_B == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); - } else { - return false; - } - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - if (init_E == cutlass::Distribution::Uniform) { - uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomSparseMeta( - tensor_E.host_view(), seed, MetaSizeInBits); - } else if (init_E == cutlass::Distribution::Identity) { - uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; - cutlass::reference::host::TensorFill(tensor_E.host_view(), - (ElementE)(content)); - } else { - return false; - } - - cutlass::reorder_meta( - tensor_E_reordered.host_ref(), tensor_E.host_ref(), - {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_E_reordered.sync_device(); - - // launch kernel - sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( - tensor_D_computed.device_data(), - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data(), - tensor_E_reordered.device_data()); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), - tensor_E.host_ref(), Shape::kM, Shape::kK); - - cutlass::reference::host::Gemm - reference_gemm; - - reference_gemm( - {Shape::kM, Shape::kN, ThreadblockShape::kK}, - ElementC(1), - tensor_A_uncompressed.host_ref(), - tensor_B.host_ref(), - ElementC(0), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - EXPECT_TRUE(passed); - - if (!passed) { - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; - - std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; - std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; - - std::cout - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << "\n"; - } - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h deleted file mode 100644 index 3311e915db892466a9a4c52c82d100c2e1319966..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h +++ /dev/null @@ -1,43 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include -#include "cutlass/cutlass.h" - -namespace cutlass { -namespace nvrtc { - -extern char const *kCutlassHeaders[]; -extern char const *kCutlassHeaderNames[]; -extern size_t const kCutlassHeaderCount; -} // namespace nvrtc -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp deleted file mode 100644 index 55df44379c847034ed38cfab23477331ee4a537c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "cute/tensor.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" - - -namespace nvrtc { -namespace thread { - -template< - typename ElementA, typename ElementB, typename ElementC, - typename TileShape, typename ClusterShape, - bool kTransA, bool kTransB, - int RANK_M, int RANK_N, int RANK_K, int RANK_L -> -struct ContractionKernel { - -using ElementScalar = float; -using ElementAccum = float; -using EpilogueThread = cutlass::epilogue::thread::LinearCombination; - -static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; -static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN; - -/// Kernel config -typedef int64_t stride_type; -typedef int32_t extent_type; - -static constexpr const stride_type* stride_null = nullptr; -static constexpr const extent_type* extent_null = nullptr; - -template -static constexpr -auto -make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { - static_assert(Rank > 1); - if constexpr (IsMajor) { - return cute::transform(cute::make_seq{}, [&](auto i) { - if constexpr (i == 0) { - return cute::Int<1>{}; - } - else { - return i < n ? t[i] : init_default; - } - }); - } - else { - return cute::make_int_tuple(t, n, init_default); - } -} - -using StrideA = decltype(cute::make_stride( - make_stride_tuple(stride_null, 0, 0), - make_stride_tuple(stride_null, 0, 0), - cute::make_int_tuple(stride_null, 0, 0))); - -using StrideB = decltype(cute::make_stride( - make_stride_tuple(stride_null, 0, 0), - make_stride_tuple(stride_null, 0, 0), - cute::make_int_tuple(stride_null, 0, 0))); - -using StrideC = decltype(cute::make_stride( - cute::make_int_tuple(stride_null, 0, 0), - cute::make_int_tuple(stride_null, 0, 0), - cute::make_int_tuple(stride_null, 0, 0))); - -using ProblemShape = decltype(cute::make_shape( - cute::make_int_tuple(extent_null, 0, 0), - cute::make_int_tuple(extent_null, 0, 0), - cute::make_int_tuple(extent_null, 0, 0), - cute::make_int_tuple(extent_null, 0, 0))); - -using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementA, StrideA, 16 / sizeof(ElementA), - ElementB, StrideB, 16 / sizeof(ElementB), - ElementAccum, - TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecialized ->::CollectiveOp; - -using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; -using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; -using Kernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveOp, - CollectiveEpilogue>; - -}; - -} // namespace nvrtc -} // namespace thread diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h deleted file mode 100644 index 576f55cd868cd64c8c09c055d8b9a956e40c87ae..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h +++ /dev/null @@ -1,76 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level GEMM -*/ - -#pragma once - -#include "cutlass/array.h" - -namespace test { -namespace nvrtc { -namespace kernel { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Thread-level matrix multiply-accumulate -template -__global__ void testbed_kernel( - typename Mma::ElementC *D, - typename Mma::ElementA const *A, - typename Mma::ElementB const *B, - typename Mma::ElementC const *C) { - - auto ptr_D = reinterpret_cast *>(D); - auto ptr_A = reinterpret_cast const *>(A); - auto ptr_B = reinterpret_cast const *>(B); - auto ptr_C = reinterpret_cast const *>(C); - - Mma mma; - - auto a = *ptr_A; - auto b = *ptr_B; - auto c = *ptr_C; - - cutlass::Array d; - - mma(d, a, b, c); - - *ptr_D = d; -} - -} -} -} -} - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h deleted file mode 100644 index c7e6e94691c82b2f343959421c884c8b0b06f9b4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h +++ /dev/null @@ -1,30 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h deleted file mode 100644 index 5ba5432fd568af71e15b20b8cdab1571f303bcdf..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h +++ /dev/null @@ -1,129 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -typedef char int8_t; -typedef unsigned char uint8_t; -typedef short int16_t; -typedef unsigned short uint16_t; -typedef int int32_t; -typedef unsigned int uint32_t; -typedef long long int int64_t; -typedef unsigned long long int uint64_t; - -#if defined __x86_64__ && !defined __ILP32__ -# define __WORDSIZE 64 -#else -# define __WORDSIZE 32 -#endif - - -/* Small types. */ - -/* Signed. */ -typedef signed char int_least8_t; -typedef short int int_least16_t; -typedef int int_least32_t; -#if __WORDSIZE == 64 -typedef long int int_least64_t; -#else -__extension__ -typedef long long int int_least64_t; -#endif - -/* Unsigned. */ -typedef unsigned char uint_least8_t; -typedef unsigned short int uint_least16_t; -typedef unsigned int uint_least32_t; -#if __WORDSIZE == 64 -typedef unsigned long int uint_least64_t; -#else -__extension__ -typedef unsigned long long int uint_least64_t; -#endif - - -/* Fast types. */ - -/* Signed. */ -typedef signed char int_fast8_t; -#if __WORDSIZE == 64 -typedef long int int_fast16_t; -typedef long int int_fast32_t; -typedef long int int_fast64_t; -#else -typedef int int_fast16_t; -typedef int int_fast32_t; -__extension__ -typedef long long int int_fast64_t; -#endif - -/* Unsigned. */ -typedef unsigned char uint_fast8_t; -#if __WORDSIZE == 64 -typedef unsigned long int uint_fast16_t; -typedef unsigned long int uint_fast32_t; -typedef unsigned long int uint_fast64_t; -#else -typedef unsigned int uint_fast16_t; -typedef unsigned int uint_fast32_t; -__extension__ -typedef unsigned long long int uint_fast64_t; -#endif - -/* Types for `void *' pointers. */ -#if __WORDSIZE == 64 -# ifndef __intptr_t_defined -typedef long int intptr_t; -# define __intptr_t_defined -# endif -typedef unsigned long int uintptr_t; -#else -# ifndef __intptr_t_defined -typedef int intptr_t; -# define __intptr_t_defined -# endif -typedef unsigned int uintptr_t; -#endif - - -/* Largest integral types. */ -#if __WORDSIZE == 64 -typedef long int intmax_t; -typedef unsigned long int uintmax_t; -#else -__extension__ -typedef long long int intmax_t; -__extension__ -typedef unsigned long long int uintmax_t; -#endif - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h deleted file mode 100644 index 8fd6863e8fa003d3fbc4e0b498e3b9b454ade190..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h +++ /dev/null @@ -1,398 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level GEMM -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/gemm/thread/mma.h" -#include "../kernel/thread/testbed_kernel.h" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/trace.h" - -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/gemm.h" - -#include -#include -#include "../cutlass/nvrtc/environment.h" -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { -namespace nvrtc { -namespace thread { - -#define NVRTC_RETURN_IF_ERROR(api) \ - do { \ - nvrtcResult _result = api; \ - if (_result != NVRTC_SUCCESS) { \ - CUTLASS_TRACE_HOST("Nvrtc error: " << _result); \ - return false; \ - } \ - } while(0) - -inline const char * cuda_source_fmt = R"""( - -#include "kernel/thread/contraction.hpp" - -using Operator = %s; - -extern "C" __global__ void global_entry(__grid_constant__ Operator::Params const params) { - extern __shared__ char smem[]; - - Operator op; - op(params, smem); -} - -)"""; - -struct TestbedKernel { - static bool compile(std::string const &kernel, std::vector const &opts) { - int sz = std::snprintf(nullptr, 0, cuda_source_fmt, kernel.c_str()); - std::vector cuda_source(sz + 1); - std::snprintf(&cuda_source[0], cuda_source.size(), cuda_source_fmt, kernel.c_str()); - - nvrtcProgram program; - NVRTC_RETURN_IF_ERROR( - nvrtcCreateProgram( - &program, - cuda_source.data(), - nullptr, - static_cast(cutlass::nvrtc::kCutlassHeaderCount), - cutlass::nvrtc::kCutlassHeaders, - cutlass::nvrtc::kCutlassHeaderNames) - ); - - nvrtcResult compile_result = - nvrtcCompileProgram( - program, - static_cast(opts.size()), - opts.data()); - - size_t log_size; - NVRTC_RETURN_IF_ERROR( - nvrtcGetProgramLogSize(program, &log_size) - ); - - if (log_size > 1) { - auto log = std::make_unique(log_size); - - NVRTC_RETURN_IF_ERROR( - nvrtcGetProgramLog(program, log.get()) - ); - - std::cout << log.get() << std::endl; - } - - NVRTC_RETURN_IF_ERROR(compile_result); - - NVRTC_RETURN_IF_ERROR( - nvrtcDestroyProgram(&program) - ); - - return true; - } -}; - -/// Structure to compute the matrix product -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape, - /// Data type of A elements - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC -> -struct Testbed { - - /// Thread-level matrix multiply-accumulate operator - using Mma = cutlass::gemm::thread::Mma< - Shape, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC - >; - - // - // Data members - // - - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed() { - - tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); - tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); - tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); - tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); - } - - static inline bool check_nvrtc_error(nvrtcResult error) { - if (error != NVRTC_SUCCESS) { - std::cerr << "failed to compile "; - return false; - } - return true; - } - - /// Runs the test - bool run(std::string const &gemm_traits) { - - // - // initialize device memory - // - - cutlass::reference::host::BlockFillSequential( - tensor_A.host_data(), - tensor_A.capacity() - ); - - cutlass::reference::host::BlockFillSequential( - tensor_B.host_data(), - tensor_B.capacity(), - ElementB(1), - ElementB(2) - ); - - cutlass::reference::host::TensorFill( - tensor_C.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_computed.host_view(), - ElementC(0) - ); - - cutlass::reference::host::TensorFill( - tensor_D_reference.host_view(), - ElementC(0) - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_D_computed.sync_device(); - -#if 0 - // launch kernel - cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( - tensor_D_computed.device_data(), - tensor_A.device_data(), - tensor_B.device_data(), - tensor_C.device_data()); - -#else - // Instantiate gemm_kernel - nvrtcResult result_nvrtc; - nvrtcProgram program; - static char const *src = - "#include \"cutlass/gemm/thread/mma.h\"\n" - "#include \"cutlass/gemm/gemm.h\"\n" - "#include \"cutlass/layout/matrix.h\"\n" - "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" - ; - - std::string type_name; -#if 0 - // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names - // As altername solution we might want to implement to_string() to get the traits string. - nvrtcGetTypeName(&type_name); -#else - type_name = gemm_traits; -#endif - - result_nvrtc = nvrtcCreateProgram(&program, - src, - NULL, - (int)cutlass::nvrtc::kCutlassHeaderCount, - cutlass::nvrtc::kCutlassHeaders, - cutlass::nvrtc::kCutlassHeaderNames); - check_nvrtc_error(result_nvrtc); - - std::string gemm_kernel_instantiation = - "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; - nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); - - const char *opts[] = {"--gpu-architecture=compute_75", - "--std=c++17", - "--include-path=/usr/local/cuda-10.1/include"}; - - result_nvrtc = nvrtcCompileProgram(program, 3, opts); - if (result_nvrtc != NVRTC_SUCCESS) { - size_t logSize; - nvrtcGetProgramLogSize(program, &logSize); - std::vector log(logSize); - nvrtcGetProgramLog(program, log.data()); - std::cout << "Compile log:" << std::endl << log.data() << std::endl; - } - if (!check_nvrtc_error(result_nvrtc)) { - assert(0); - } - - // The lowered name is the name of the template instantiation in the generated PTX code. - char const *gemm_kernel_lowered_name; - nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); - if (!check_nvrtc_error(result_nvrtc)) { - assert(0); - } - - // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards - size_t ptx_size; - result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); - if (!check_nvrtc_error(result_nvrtc)) { - assert(0); - } - - std::vector ptx(ptx_size); - result_nvrtc = nvrtcGetPTX(program, ptx.data()); - if (!check_nvrtc_error(result_nvrtc)) { - assert(0); - } - - // we do not need the nvrtc program anymore - //nvrtcDestroyProgram(&program); - - CUmodule module; - CUresult result_cuda; - result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); - if (result_cuda != CUDA_SUCCESS) { - assert(0); - } - - CUfunction kernel; - result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); - if (result_cuda != CUDA_SUCCESS) { - assert(0); - } - - void* d_a = (void*)tensor_A.device_data(); - void* d_b = (void*)tensor_B.device_data(); - void* d_c = (void*)tensor_C.device_data(); - void* d_d = (void*)tensor_D_computed.device_data(); - void* args[] = { &d_d, &d_a, &d_b, &d_c }; - - // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra - result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); - if (result_cuda != CUDA_SUCCESS) { - assert(0); - } else { -} -#endif - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - if (result != cudaSuccess) { - std::cout << "CUDA ERROR: " << cudaGetErrorString(result); - return false; - } - - tensor_D_computed.sync_host(); - - // - // Reference implementation - // - - //tensor_D_reference.fill(tensor_C.host_view()); - - cutlass::reference::host::Gemm reference_gemm; - - reference_gemm( - {Shape::kM, Shape::kN, Shape::kK}, - ElementC(1), - tensor_A.host_ref(), - tensor_B.host_ref(), - ElementC(0), - tensor_D_reference.host_ref() - ); - - // - // Verify equivalence - // - - // compare - bool passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view() - ); - - if(!passed) std::cout - << "A:\n" << tensor_A.host_view() << "\n\n" - << "B:\n" << tensor_B.host_view() << "\n\n" - << "C:\n" << tensor_C.host_view() << "\n\n" - << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" - << "Computed:\n" << tensor_D_computed.host_view() << std::endl; - - std::cout << "passed " << passed << std::endl; - - return passed; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace nvrtc -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h deleted file mode 100644 index 6cc2946a2c51cfb8c1971345c81c1910bd667208..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h +++ /dev/null @@ -1,145 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Common Testbed file shared by Pipeline unit tests -*/ - -#include -#include -#include -#include - -#include "cutlass/util/command_line.h" -#include "../common/cutlass_unit_test.h" - -#if CUDA_12_0_SM90_FEATURES_SUPPORTED - #define CUTLASS_UNIT_TEST_PIPELINE true -#else - #define CUTLASS_UNIT_TEST_PIPELINE false -#endif - -// Command line test options -struct Options { - // - // Data Members - // - bool help; - bool verification_enabled; - int SM_count; - int clock_MHz; - - // - // Methods - // - Options(): - help(false), - verification_enabled(true), - SM_count(116), - clock_MHz(1477) - { } - - void parse(int argc, char const **args) { - cutlass::CommandLine cmd(argc, args); - - if (cmd.check_cmd_line_flag("help")) { - help = true; - } - - cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); - cmd.get_cmd_line_argument("sm-count", SM_count, 116); - cmd.get_cmd_line_argument("clock", clock_MHz, 1477); - } - - /// Prints the usage statement. - std::ostream & print_usage(std::ostream &out) const { - - out << "Options:\n\n" - << " --help If specified, displays this usage statement.\n\n" - << " --verification-enabled= Enable/Disable verification\n" - << " --sm-count= Number of SMs on the chip\n" - << " --clock= Locked clock value in Mhz\n"; - - return out; - } -}; - -// -// Testbed -// - -template -struct Testbed { -private: - // Commandline options - Options options; - - void run_test(uint32_t const kNumIters) { - - // Run CuTe Gemm - Pipeline pipeline; - - cudaError_t result = pipeline.run(kNumIters); - - CUTE_CHECK_LAST(); - } - - -public: - Testbed(Options const &options_) : options(options_) { - int device_id = 0; - cudaDeviceProp device_prop; - CUTE_CHECK_ERROR(cudaSetDevice(device_id)); - CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); - - if (device_prop.major < 1) { - fprintf(stderr, "Device does not support CUDA.\n"); - exit(1); - } - } - - /// Run verification Gemm problem sizes - bool verification() { - - std::array kNumIters; - - for (size_t i = 0; i < kNumIters.size(); ++i) { - kNumIters[i] = static_cast( (rand() % 1000) + 1 ); - } - - for (int n : kNumIters) { - std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; - run_test(n); - } - - return true; - } -}; diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h deleted file mode 100644 index 50a68a1437956c95aa4e7912e93adc8b1481c9cc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h +++ /dev/null @@ -1,154 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Testbed file used by cluster launch control pipeline unit test -*/ - -// - -// - -#if CUDA_12_0_SM90_FEATURES_SUPPORTED - #define CUTLASS_UNIT_TEST_PIPELINE true -#else - #define CUTLASS_UNIT_TEST_PIPELINE false -#endif - -#include -#include -#include -#include - -#include "cutlass/util/command_line.h" - -// Command line test options -struct OptionsClusterLaunch { - // - // Data Members - // - bool help = false; - bool verification_enabled = true; - int SM_count = 116; - int clock_MHz = 1477; - dim3 grid_dim = {0,0,0}; - - // - // Methods - // - - void parse(int argc, char const **args) { - cutlass::CommandLine cmd(argc, args); - - if (cmd.check_cmd_line_flag("help")) { - help = true; - } - - cmd.get_cmd_line_argument("verification-enabled", verification_enabled, verification_enabled); - cmd.get_cmd_line_argument("sm-count", SM_count, SM_count); - cmd.get_cmd_line_argument("clock", clock_MHz, clock_MHz); - } - - /// Prints the usage statement. - std::ostream & print_usage(std::ostream &out) const { - - out << "Options:\n\n" - << " --help If specified, displays this usage statement.\n\n" - << " --verification-enabled= Enable/Disable verification\n" - << " --sm-count= Number of SMs on the chip\n" - << " --clock= Locked clock value in Mhz\n"; - - return out; - } -}; - -// -// Testbed -// - -template -class TestbedClusterLaunch { -private: - // Commandline options - OptionsClusterLaunch options; - - bool run_test() { - - // Run CuTe Gemm - Pipeline pipeline; - - bool success = false; - cudaError_t result = pipeline.run(success, this->options.grid_dim); - - CUTE_CHECK_LAST(); - return success; - } - - -public: - TestbedClusterLaunch(OptionsClusterLaunch const &options_) : options(options_) { - int device_id = 0; - cudaDeviceProp device_prop; - CUTE_CHECK_ERROR(cudaSetDevice(device_id)); - CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); - - if (device_prop.major < 1) { - fprintf(stderr, "Device does not support CUDA.\n"); - exit(1); - } - } - - /// Run verification Gemm problem sizes - bool verification() { - -#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - printf( - "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be set, but it is not. \n" - "This test is waived.\n" - ); - return true; -#endif - -#if 0 - bool is_success = false; - for (int i = 0; i< 10; i++){ - printf("iteration = %d\n", i); - is_success = run_test(); - if ( not is_success ) - return is_success; - } - return is_success; -#else - // Run the test with single launch - return run_test(); -#endif - } -}; diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h deleted file mode 100644 index e44a42463ae95e4f76388d791c661de875092c93..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h +++ /dev/null @@ -1,45 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level Reduction -*/ - -#pragma once - -#include "cutlass/reduction/thread/reduce.h" - -#include "cutlass/layout/vector.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_compare.h" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h deleted file mode 100644 index 239f228831a25527106af1659383112535943df1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h +++ /dev/null @@ -1,242 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Unit tests for thread-level Reduction -*/ - -#pragma once - -#include "cutlass/reduction/thread/reduce.h" - -#include "cutlass/layout/vector.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" - -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_compare.h" - -namespace test { -namespace reduction { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the reduction -template < - /// Data type of elements - typename Element, - /// Number of elements - int N -> -struct Testbed_reduce_host { - - /// Thread-level reduction operator - using Reduce = cutlass::reduction::thread::Reduce< - cutlass::plus, - cutlass::Array - >; - - // - // Data members - // - - cutlass::Array tensor_in; - cutlass::Array reduced_tensor_computed; - cutlass::Array reduced_tensor_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed_reduce_host() { - tensor_in.clear(); - reduced_tensor_computed.clear(); - reduced_tensor_reference.clear(); - } - - /// Runs the test - bool run() { - - // - // initialize memory - // - - for(int i = 0; i < N; i++) - tensor_in.at(i) = Element(i); - - - Reduce reduce; - - cutlass::Array *out_ptr = &reduced_tensor_computed; - out_ptr[0] = reduce(tensor_in); - - // - // Reference implementation - // - Element e(0); - for (int i = 0; i < N; i++) - e = e + Element(i); - - reduced_tensor_reference.at(0) = e; - - // - // Verify equivalence - // - - // compare - bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; - - EXPECT_TRUE(passed) - << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" - << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" - << std::endl; - - return passed; - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Thread-level reduction kernel -template -__global__ void kernel_reduce(Element const *array_in, Element *result) { - - /// Thread-level reduction operator - using Reduce = cutlass::reduction::thread::Reduce< - cutlass::plus, - cutlass::Array - >; - - Reduce reduce; - - auto ptr_in = reinterpret_cast const *>(array_in); - auto result_ptr = reinterpret_cast *>(result); - auto in = *ptr_in; - result_ptr[0] = reduce(in); -} - - -/// Structure to compute the reduction -template < - /// Data type of elements - typename Element, - /// Number of elements - int N -> -struct Testbed_reduce_device { - - using Layout = cutlass::layout::PackedVectorLayout; - - // - // Data members - // - - cutlass::HostTensor tensor_in; - cutlass::HostTensor reduced_tensor_computed; - cutlass::HostTensor reduced_tensor_reference; - - // - // Methods - // - - /// Allocates workspace in device memory - Testbed_reduce_device() { - - tensor_in.reset(cutlass::make_Coord(N), true); - reduced_tensor_computed.reset(cutlass::make_Coord(1), true); - reduced_tensor_reference.reset(cutlass::make_Coord(1), true); - } - - - /// Runs the test - bool run() { - - // - // initialize memory - // - - cutlass::reference::host::TensorFill( - tensor_in.host_view(), - Element(1) - ); - - cutlass::reference::host::TensorFill( - reduced_tensor_computed.host_view(), - Element(0) - ); - - cutlass::reference::host::TensorFill( - reduced_tensor_reference.host_view(), - Element(N) - ); - - tensor_in.sync_device(); - reduced_tensor_computed.sync_device(); - reduced_tensor_reference.sync_device(); - - /// call the kernel - kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( - tensor_in.device_data(), - reduced_tensor_computed.device_data() - ); - - // verify no errors - cudaError_t result = cudaDeviceSynchronize(); - - EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); - if (result != cudaSuccess) { - return false; - } - - // Copy back results - reduced_tensor_computed.sync_host(); - - // Verify equivalence - bool passed = cutlass::reference::host::TensorEquals( - reduced_tensor_computed.host_view(), - reduced_tensor_reference.host_view() - ); - - EXPECT_TRUE(passed) - << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" - << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" - << std::endl; - - return passed; - } -}; - -} // namespace thread -} // namespace reduction -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp deleted file mode 100644 index c4e7de4351076dba3a699b4cb1c8a6e01485bc20..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp +++ /dev/null @@ -1,481 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Compress utils specific for SM90 structure sparse kernels -*/ - -#pragma once - -#include // std::fill -#include // std::array -#include -#include // std::mt19937 - -#include "cute/container/bit_field.hpp" // cute::bit_field -#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v -#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor -#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 -#include "cutlass/cutlass.h" // cutlass::Status -#include "cutlass/detail/collective.hpp" -#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t -#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up -#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo -#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride -#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes -#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter - -namespace cutlass -{ -namespace transform -{ -namespace kernel -{ - -using namespace cute; - -namespace detail { - - template - CUTLASS_HOST_DEVICE - static uint8_t - encode_in_chunk_idx_legacy(int in_chunk_idx){ - if (sizeof(T) == 4) { - return in_chunk_idx == 0 ? 0b0100 : 0b1110; - } - else { - uint8_t res = 0; - if (in_chunk_idx == 0) { - res = 0b00; - } - else if (in_chunk_idx == 1) { - res = 0b01; - } - else if (in_chunk_idx == 2) { - res = 0b10; - } - else { - res = 0b11; - } - return res; - } - } - - template < - class SparseConfig, - class EngineA, - class LayoutA, - class EngineAc, - class LayoutAc - > - CUTLASS_HOST_DEVICE - static void - compress_two_chunks_legacy( - Tensor tensorA, - Tensor tensorAc, - uint8_t& meta_two_chunk, - int effective_elems) { - - using ElementA = typename EngineAc::value_type; - - static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; - static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; - static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; - static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{}; - static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - - /* - Legal metadata chunk in SM90 - Index Bin HEX - 0, 1 0b0100 4 - 1, 2 0b1001 9 - 2, 3 0b1110 E - 0, 2 0b1000 8 - 1, 3 0b1101 D - 0, 3 0b1100 C - 2, 1 0b0110 6 (Not used) - ----------------------------------- - TF32 - 0 0b0100 4 - 1 0b1110 E - */ - - if (effective_elems <= 0) { - return; - } - - // initialize - // 0 is the initial value for this function while 0x44 is the initial value for hardware. - meta_two_chunk = 0; - - for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) { - // If Only One Chunk within this Two Chunk - if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) { - break; - } - /// init result; - int non_zero_cnt = 0; - int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 }; - ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} }; - - for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) { - bool is_nz = true; - ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; - /// Check if subchunk is non-zero - for(int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) { - int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx; - subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0); - - ElementA zero = static_cast(0); - ElementA minus_zero = static_cast(ElementA(1) << cutlass::sizeof_bits_v - 1); - if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) { - if (non_zero_cnt >= PhysicalSubChunk) { - #ifdef __CUDA_ARCH__ - asm volatile ("brkpt;\n" ::); - #else - throw std::runtime_error("Found extra non-zero elements in a chunk!\n"); - #endif - } - is_nz = false; - } - } - - /// There is non-zero element in the subchunk - if(!is_nz) { - nnz_chunk_idx[non_zero_cnt] = subchunk_idx; - memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); - non_zero_cnt++; - } - } - - /* - Special cases - nnz == 1 and non-tf32 and nnz_idx = 3 - */ - ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; - if constexpr (sizeof_bits_v < 32) { - if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) { - memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw); - memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); - nnz_chunk_idx[1] = 3; - nnz_chunk_idx[0] = 0; - } - else if (non_zero_cnt == 1) { - memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); - nnz_chunk_idx[1] = 3; - } - } - - /// Setup metadata - uint8_t meta_chunk = 0; - for (int i = 0; i < PhysicalSubChunk; i++) { - meta_chunk = static_cast(meta_chunk | (encode_in_chunk_idx_legacy(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma))); - for(int j = 0; j < ElemsARawPerElementAMmaRaw; j++) { - tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j]; - } - } - meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{}))); - } - } -} - -template< - class ProblemShape_, - class ElementA_, - class LayoutATag_, - class SparseConfig_ -> -class SM90StructuredSparseCompressorLegacy { -public: - using SparseConfig = SparseConfig_; - using ProblemShape = ProblemShape_; - - // * EltA - using ElementA = ElementA_; - using ElementAUint = cute::uint_bit_t>; - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - using ArrayElementA = cute::conditional_t>, - ElementA>; - using ElementAMma = typename SparseConfig::ElementAMma; - using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; - using ElementASparsity = typename SparseConfig::ElementASparsity; - using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; - using LayoutATag = LayoutATag_; - using LayoutA = LayoutATag; - using StrideA = cutlass::gemm::TagToStrideA_t; - - // * EltE - using ElementEMma = typename SparseConfig::ElementEMma; - using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; - using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; - - // * AtomE - using TensorEAtom = typename SparseConfig::TensorEAtom; - using TensorEAtomK = typename SparseConfig::TensorEAtomK; - using TensorEAtomM = typename SparseConfig::TensorEAtomM; - - static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; - static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; - static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; - static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); - - // * Alignment - static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; - static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; - static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; - static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; - - // Required by `device_kernel` - static constexpr int MaxThreadsPerBlock = 1; - static constexpr int MinBlocksPerMultiprocessor = 1; - using ArchTag = arch::Sm90; - - struct SharedStorage { - /* empty, no smem needed */ - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - struct TransformArguments { - ArrayElementA const* ptr_A{nullptr}; - StrideA dA{}; - ArrayElementA* ptr_ACompress{nullptr}; - ElementEMmaRaw* ptr_E{nullptr}; - }; - - using TransformParams = TransformArguments; - - struct Arguments { - ProblemShape problem_shape{}; - TransformArguments transform{}; - KernelHardwareInfo hw_info{}; - }; - - struct Params { - ProblemShape problem_shape{}; - TransformParams transform{}; - KernelHardwareInfo hw_info{}; - void* workspace = nullptr; - }; - - static Params - to_underlying_arguments(Arguments & args, void* workspace) { - return Params{{args.problem_shape}, - {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, - {args.hw_info}, - workspace}; - } - - static Status - can_implement(Arguments const& args) { - auto [M, N, K, L] = args.problem_shape; - if (K % LogicalElemsAPerChunk != 0) { - CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size\n"); - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - - static size_t - get_workspace_size(Arguments const& args) { - auto problem = args.problem_shape; - const int m = cute::size<0>(problem); - const int k = cute::size<2>(problem); - const int l = cute::size<3>(problem); - const int metadata_k = round_up(k, TensorEAlignmentK); - const int metadata_m = round_up(m, TensorEAlignmentM); - const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l; - return metadata_bytes; - } - - static Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - cudaError_t cuda_error; - - auto workspace_size = get_workspace_size(args); - if (workspace_size == 0) { - return Status::kSuccess; - } else if (workspace == nullptr) { - return Status::kErrorInternal; - } - - cudaPointerAttributes attri; - cuda_error = cudaPointerGetAttributes(&attri, workspace); - if (cuda_error != cudaSuccess) { - return Status::kErrorInternal; - } - - if ( attri.type == cudaMemoryTypeDevice ) { -#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER - CUTLASS_ASSERT(cuda_adapter); - if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { - return Status::kErrorInternal; - } -#else - cudaMemsetAsync(workspace, 0, workspace_size, stream); - cuda_error = cudaGetLastError(); - if (cuda_error != cudaSuccess) { - return Status::kErrorInternal; - } -#endif - } else { - memset(workspace, 0, workspace_size); - } - - return Status::kSuccess; - } - - static dim3 - get_grid_shape(Params const& params) { - return dim3(1, 1, 1); - } - - static dim3 - get_block_shape() { - return dim3(1, 1, 1); - } - - CUTE_HOST_DEVICE - void - operator()(Params params, char* smem_buf = nullptr) { - run(params, smem_buf); - } - - CUTE_HOST_DEVICE - static void - run(Params params, char* smem_buf = nullptr) { - do_compress_device_host(params); - } - -private: - - CUTE_HOST_DEVICE - static void - do_compress_device_host(Params params) { - auto [m, n, k, l] = params.problem_shape; - auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; - auto workspace = params.workspace; - - const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK; - const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM; - const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK; - const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM; - const int k_compressed = aligned_k / ElementASparsity{}; - - // Convert to CuTe tensors. But don't want to use sparse_ptr, which is making everything complicated here. - cute::Tensor tensorA = make_tensor(recast_ptr(ptr_A), make_layout(make_shape(m, k, l), dA)); - - cute::Tensor tensorAc = make_tensor(recast_ptr(ptr_ACompress), - make_shape(aligned_m, k_compressed, l), - make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l))); - - cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr>(workspace), - make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l), - make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k)); - - cute::Tensor tensorE_raw_compress = recast(tensorE_raw_compress_logical); - - // The following vars are all logical. - int atom_m = size<0>(TensorEAtom{}); - int atom_k = size<1>(TensorEAtom{}); - int tiled_m = metadata_m / atom_m; - int tiled_ke = metadata_k / atom_k; - // Col major when viewing atoms - int stride_tile_m = cosize(TensorEAtom{}); - int stride_tile_ke = atom_k * metadata_m; - - // Logical metadata tensor - cute::Tensor tensorE_logical = make_tensor(recast_ptr>(ptr_E), - make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m), - append(shape<1>(TensorEAtom{}), tiled_ke), - shape<2>(tensorE_raw_compress_logical)), - make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m), - append(stride<1>(TensorEAtom{}), stride_tile_ke), - stride<2>(tensorE_raw_compress_logical)))); - // Physical metadata tensor - cute::Tensor tensorE = recast(tensorE_logical); - - // void do_init() - cute::clear(tensorAc); - cute::clear(tensorE_raw_compress); - - // void do_raw_compress() - using TileStepA = Int; - using TileStepAc = Int; - - cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _)); - cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _)); - - for (int batch_idx = 0; batch_idx < l; batch_idx++) { - for (int m_idx = 0; m_idx < m; m_idx++) { - for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) { - int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{})); - detail::compress_two_chunks_legacy(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), - tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), - tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx), - effective_elems); - } - } - } - - // void do_reorder() - // Fast path when we don't permute. - if constexpr (sizeof_bits_v <= 8) { - memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size()); - } - else { - cute::copy(tensorE_raw_compress, tensorE); - } - - #if 0 - print("--> TensorA\n"); - auto tensorA_eltA = cute::recast(tensorA); - cute::print_tensor(tensorA_eltA); printf("\n\n"); - - print("--> REF TensorAC\n"); - auto tensorAc_eltA = cute::recast(tensorAc); - cute::print_tensor(tensorAc_eltA); printf("\n\n"); - - print("--> REF TensorE\n"); - cute::print_tensor(tensorE); printf("\n\n"); - #endif - - } -}; - -} // namespace kernel -} // namespace transform -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp deleted file mode 100644 index f44458244e0d3c4c80ecc29a0115cd6906211559..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp +++ /dev/null @@ -1,877 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/* - * @brief Test for structured sparse gemm compressor device kernel - */ - -#pragma once - -#include // cudaGetLastError - -#include // uint64_t -#include // printf -#include // malloc -#include // std::cout -#include -#include - -#include "cute/layout.hpp" // cute::make_shape -#include "cute/util/type_traits.hpp" // cute::is_same_v -#include "cutlass/coord.h" // cutlass::make_Coord -#include "cutlass/cutlass.h" // cutlass::Status -#include "cutlass/kernel_hardware_info.hpp" // cutlass::KernelHardwareInfo -#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory -#include "cutlass/numeric_types.h" // cutlass::sizeof_bits, cutlass::float_ -#include "cutlass/tensor_view.h" // cutlass::TensorView -#include "cutlass/transform/device/transform_universal_adapter.hpp" // cutlass::transform::device::TransformUniversalAdapter -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // cutlass::transform::kernel::StructuredSparseCompressorUtility -#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation -#include "cutlass/util/distribution.h" // cutlass::Distribution -#include "cutlass/util/host_tensor.h" // cutlass::HostTensor -#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride -#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals -#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill -#include "cutlass/detail/collective.hpp" - -#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor -#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE - - -#define CUDA_CHECK_FALSE(cuda_error) \ - { \ - if (cuda_error != cudaSuccess) { \ - printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ - return false; \ - } \ - } - -#define CUDA_CHECK(cuda_error) \ - { \ - if (cuda_error != cudaSuccess) { \ - printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ - return; \ - } \ - } - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// * Test Bed -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test -{ -namespace transform -{ -namespace device -{ - -// Helper Functions -template -bool -initialize_tensor(cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) -{ - if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } - else if (bits_input <= 8) { - scope_max = 1; - scope_min = -1; - } else { - scope_max = 4; - scope_min = -4; - } - cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0); - } - - else if (dist_kind == cutlass::Distribution::Identity) { - cutlass::reference::host::TensorFillIdentity(view); - } - - else if (dist_kind == cutlass::Distribution::Gaussian) { - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - - else if (dist_kind == cutlass::Distribution::AllOnes) { - cutlass::reference::host::TensorFill(view, Element(1)); - } - - else if (dist_kind == cutlass::Distribution::AllZeros) { - cutlass::reference::host::TensorFill(view, Element(0)); - } - - else { - EXPECT_TRUE(false) << "Not implemented"; - return false; - } - - return true; -} - -// Testbed -template -struct TestbedSparseGemmCompressor { -public: - using Compressor = Compressor_; - using CompressorKernel = typename Compressor::TransformKernel; - - using ElementA = typename CompressorKernel::ElementA; - using LayoutATag = typename CompressorKernel::LayoutATag; - using StrideA = typename CompressorKernel::StrideA; - using ArrayElementA = - ElementA - ; - - using ElementE = typename CompressorKernel::ElementEMmaRaw; - using LayoutETag = cutlass::layout::RowMajor; // We don't care about the major here, just to allocate tensor - - using SparseConfig = typename CompressorKernel::SparseConfig; - using ProblemShapeType = typename CompressorKernel::ProblemShape; - - using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< - ProblemShapeType, - ElementA, - LayoutATag, - SparseConfig>; - - using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy< - ProblemShapeType, - ElementA, - LayoutATag, - SparseConfig>; - - using CompressorHost = cutlass::transform::device::TransformUniversalAdapter; - - static constexpr auto LogicalElemsAPerChunk = CompressorKernel::LogicalElemsAPerChunk; - static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk; - - struct Data { - // Data Storage - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_A_Comp; - cutlass::HostTensor tensor_E; - cutlass::HostTensor tensor_A_Comp_ref; - cutlass::HostTensor tensor_E_ref; - }; - - struct CudaRAII { - cudaStream_t stream; - cudaEvent_t start; - cudaEvent_t stop; - - CudaRAII(){ - CUDA_CHECK(cudaStreamCreate( &stream )); - CUDA_CHECK(cudaEventCreate( &start )); - CUDA_CHECK(cudaEventCreate( &stop )); - }; - - CudaRAII(const CudaRAII&) = delete; - CudaRAII& operator=(const CudaRAII&) = delete; - CudaRAII(CudaRAII&&) = delete; - CudaRAII& operator=(CudaRAII&&) = delete; - - ~CudaRAII(){ - CUDA_CHECK(cudaStreamDestroy( stream )); - CUDA_CHECK(cudaEventDestroy( start )); - CUDA_CHECK(cudaEventDestroy( stop )); - } - }; - -public: - TestbedSparseGemmCompressor( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform, - uint64_t seed_ = 7) - : init_A(init_A_) - , init_E(init_E_) - , init_A_Comp(init_A_Comp_) - , seed(seed_) - { - } - - bool valid_test(ProblemShapeType problem_shape_MNKL) - { - const int GemmK = cute::size<2>(problem_shape_MNKL); - - if ( GemmK % LogicalElemsAPerChunk != 0 ) { - printf("GemmK needs to be multiplier of LogicalElemsAPerChunk\n"); - return false; - } - - return true; - } - - bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas) - { - CUDA_CHECK_FALSE(cudaGetLastError()); - - // In unit of ElementARaw - const int GemmM = cute::size<0>(problem_shape_MNKL); - const int GemmN = cute::size<1>(problem_shape_MNKL); - const int GemmK = cute::size<2>(problem_shape_MNKL); - const int GemmL = cute::size<3>(problem_shape_MNKL); - - // Compressor utility to get allocated data size - auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); - CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); - - // TensorA - // In unit of ElementARaw, after alignment requirement - // M-dim: no alignment requirement - // K-dim: multiplier of chunk size - - // TensorA Compressed - // In unit of ElementARaw, after alignment requirement - // M-dim: TMA alignment - // K-dim: TMA alignment - const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical(); - const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical(); - - // TensorE - // In unit of ElementE (uint8_t), after alignment requirement - // M-dim: TensorEAtom_M alignment - // K-dim: TensorEAtom_K alignment - const int GemmMAlignedE = compressor_utility.get_metadata_m_physical(); - const int GemmKAlignedE = compressor_utility.get_metadata_k_physical(); - - auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK); - auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE); - auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC); - - typename LayoutATag::Stride stride_factor_A; - typename LayoutETag::Stride stride_factor_E; - - datas.tensor_A.resize(a_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); - datas.tensor_A_Comp.resize(a_comp_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); - datas.tensor_A_Comp_ref.resize(a_comp_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A), - false); - datas.tensor_E.resize(e_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); - datas.tensor_E_ref.resize(e_coord, - cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E), - false); - - EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1)); - EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, seed + 2)); - EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3)); - EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4)); - EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5)); - - compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6); - - // Check for failed devide - CUDA_CHECK_FALSE(cudaGetLastError()); - - datas.tensor_A.sync_device(); - datas.tensor_A_Comp.sync_device(); - datas.tensor_E.sync_device(); - - // Check for failed devide - CUDA_CHECK_FALSE(cudaGetLastError()); - - return true; - } - - bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr) - { - CudaRAII cuda_raii; - - const int GemmM = cute::size<0>(problem_shape_MNKL); - const int GemmN = cute::size<1>(problem_shape_MNKL); - const int GemmK = cute::size<2>(problem_shape_MNKL); - const int GemmL = cute::size<3>(problem_shape_MNKL); - - StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Compressor::Arguments arguments{ - {GemmM, GemmN, GemmK, GemmL}, - {datas.tensor_A.device_data(), - stride_a, - datas.tensor_A_Comp.device_data(), - datas.tensor_E.device_data()}, - {hw_info} - }; - - Compressor compressor_op; - size_t workspace_size = Compressor::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - cutlass::Status status {cutlass::Status::kSuccess }; - - status = compressor_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - CUDA_CHECK_FALSE(cudaGetLastError()); - } - - status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream); - if (status != cutlass::Status::kSuccess) { - CUDA_CHECK_FALSE(cudaGetLastError()); - } - - CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); - CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream)); - - status = compressor_op.run(cuda_raii.stream); - if (status != cutlass::Status::kSuccess) { - CUDA_CHECK_FALSE(cudaGetLastError()); - } - - CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream)); - CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop)); - CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); - if ( time != nullptr ){ - CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop)); - } - - datas.tensor_A_Comp.sync_host(); - datas.tensor_E.sync_host(); - - #if 0 - { - printf("\n--> DEVICE OUTPUT\n"); - printf("datas.tensor_A\n"); - std::cout << datas.tensor_A.host_view() << std::endl << std::endl; - printf("datas.tensor_A_Comp\n"); - std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl; - printf("datas.tensor_E\n"); - std::cout << datas.tensor_E.host_view() << std::endl << std::endl; - } - #endif - - return true; - } - - bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas) - { - const int GemmM = cute::size<0>(problem_shape_MNKL); - const int GemmN = cute::size<1>(problem_shape_MNKL); - const int GemmK = cute::size<2>(problem_shape_MNKL); - const int GemmL = cute::size<3>(problem_shape_MNKL); - - StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); - - typename CompressorKernelHost::Arguments arguments{ - {GemmM, GemmN, GemmK, GemmL}, - {datas.tensor_A.host_data(), - stride_a, - datas.tensor_A_Comp_ref.host_data(), - datas.tensor_E_ref.host_data()}, - {}}; - - const auto can_imp = CompressorKernelHost::can_implement(arguments); - if (can_imp != cutlass::Status::kSuccess) { - printf("can_implement() check failed\n"); - return false; - } - - // Relies on std::vector for RAII - auto workspace_size = - static_cast::size_type>(CompressorKernelHost::get_workspace_size(arguments)); - std::vector workspace_vector(workspace_size); - auto workspace = static_cast(workspace_vector.data()); - - cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace); - if (status != cutlass::Status::kSuccess) { - printf("initialize_workspace() failed\n"); - return false; - } - - auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace); - CompressorKernelHost::run(params); - - return true; - } - - bool compare_reference(Data& datas) - { - bool check_tensor_a_compressed = - cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view()); - if (!check_tensor_a_compressed) { - printf("A-Compressed Mismatch\n"); - } - - bool check_tensor_e = cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view()); - if (!check_tensor_e) { - printf("E Mismatch\n"); - } - - return check_tensor_a_compressed && check_tensor_e; - } - - bool run_auto_small() - { - return run_auto(true); - } - - bool run_auto(bool run_small = false) - { - constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; - constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; - constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; - - constexpr int GemmN = 1; - - using ProblemType = typename std::array; - - std::vector problems; - - const std::vector problems_multiplier_of_tensor_e_atom = { - // * Regular Cases (multiplier of TensorEAlignment) - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1}, - - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1}, - - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1}, - - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2}, - - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2}, - - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2}, - - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, - {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3}, - - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, - {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3}, - - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, - {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3}, - }; - - const std::vector problems_multiplier_of_tensor_e_atom_large = { - // * Large Case (multiplier of TensorEAlignment) - {TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1}, - // {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2}, - // {TensorEAlignmentM * 12, GemmN, TensorEAlignmentK * 15, 3}, - }; - - const std::vector problems_multiplier_of_twochunk { - // * Corner Cases - {4, GemmN, LogicalElemsAPerChunk * 2, 1}, - {4, GemmN, LogicalElemsAPerChunk * 4, 1}, - {4, GemmN, LogicalElemsAPerChunk * 6, 1}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, - - {4, GemmN, LogicalElemsAPerChunk * 2, 2}, - {4, GemmN, LogicalElemsAPerChunk * 4, 2}, - {4, GemmN, LogicalElemsAPerChunk * 6, 2}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, - - {4, GemmN, LogicalElemsAPerChunk * 2, 3}, - {4, GemmN, LogicalElemsAPerChunk * 4, 3}, - {4, GemmN, LogicalElemsAPerChunk * 6, 3}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, - - {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, - - {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, - - {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, - - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, - - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, - - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, - - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, - - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, - - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, - }; - - const std::vector problems_multiplier_of_onechunk { - {4, GemmN, LogicalElemsAPerChunk * 1, 1}, - {4, GemmN, LogicalElemsAPerChunk * 3, 1}, - {4, GemmN, LogicalElemsAPerChunk * 5, 1}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, - - {4, GemmN, LogicalElemsAPerChunk * 1, 2}, - {4, GemmN, LogicalElemsAPerChunk * 3, 2}, - {4, GemmN, LogicalElemsAPerChunk * 5, 2}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, - - {4, GemmN, LogicalElemsAPerChunk * 1, 3}, - {4, GemmN, LogicalElemsAPerChunk * 3, 3}, - {4, GemmN, LogicalElemsAPerChunk * 5, 3}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, - {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, - {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, - - {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, - - {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, - - {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, - {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, - {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, - {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, - - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, - - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, - - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, - {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, - {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, - - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, - - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, - - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, - {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, - }; - - // Run small only run multiplier of chunk size cases - if (run_small) { - problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); - } - // Run full run all corner cases - else { - problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end()); - problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); - problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end()); - problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end()); - } - - for (const auto& problem_shape_MNKL : problems) { - const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; - bool passed = run({GemmM, GemmN, GemmK, GemmL}); - printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, GemmK, GemmL, passed ? "PASS" : "FAIL"); - CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << passed ? " PASS" : " FAIL"); - if (not passed) { - return false; - } - } - - return true; - } - - bool run(ProblemShapeType problem_shape_MNKL) - { - // Check if valid test - if (not valid_test(problem_shape_MNKL)) { - CUTLASS_TRACE_HOST("valid_test() fail\n"); - return false; - } - - // Data Storage - Data datas; - - // Initialize Data - if (not initialize(problem_shape_MNKL, datas)) { - CUTLASS_TRACE_HOST("initialize() fail\n"); - return false; - } - - // Run Compressor (Host Ref) - if (not run_host_ref(problem_shape_MNKL, datas)) { - CUTLASS_TRACE_HOST("run_host() fail\n"); - return false; - } - - // Run Compressor (Device) - if (not run_device(problem_shape_MNKL, datas)) { - CUTLASS_TRACE_HOST("run_device() fail\n"); - return false; - } - - // Verify - if (not compare_reference(datas)) { - CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n"); - printf("compare_reference() DEVICE <-> LEGACY HOST fail\n"); - return false; - } - // else { - // printf("DEVICE <-> HOST PASS\n"); - // } - - return true; - } - - bool benchmark(ProblemShapeType problem_shape_MNKL) { - const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; - printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL); - - // Check if valid test - if (valid_test(problem_shape_MNKL) == false) { - CUTLASS_TRACE_HOST("valid_test() fail\n"); - return false; - } - - // 2 warm-up iterations and 10 timing iterations - constexpr int num_warmup = 5; - constexpr int num_iter = 10; - - // Duplicate data to mimic cold cache - Data data[num_warmup + num_iter]; - double total_time_milliseconds{0.0}; - - for (int i = 0; i < num_warmup + num_iter; ++i ) { - printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i ); - - auto& datum_i = data[i]; - - // Initialize Data - if (initialize(problem_shape_MNKL, datum_i) == false) { - CUTLASS_TRACE_HOST("initialize() fail\n"); - return false; - } - - // Run Compressor (Device) - double time_i_milliseconds{0.0f}; - if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) { - CUTLASS_TRACE_HOST("run_device() fail\n"); - return false; - } - - if ( i >= num_warmup ) { - total_time_milliseconds += time_i_milliseconds; - } - } - - const double mean_time_milliseconds = total_time_milliseconds / num_iter; - printf("Mean time (ms): %.5f\n", mean_time_milliseconds); - - return true; - } - -public: - // Data Init Setting - cutlass::Distribution::Kind init_A; - cutlass::Distribution::Kind init_A_Comp; - cutlass::Distribution::Kind init_E; - uint64_t seed; -}; - -} // namespace device -} // namespace transform -} // namespace test diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h deleted file mode 100644 index df241e3ca6e6e584af7351402d990a8028e2abed..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h +++ /dev/null @@ -1,156 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - - \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. - - Generally, - - description - compile-time constant parameters used to instantiate an operation - - configuration - runtime parameters with computationally expensive initialization - - arguments - runtime parameters that may be passed to an initialized operation with low - computational overhead -*/ - -#pragma once - -#include "cutlass/arch/mma.h" -#include "cutlass/arch/arch.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct ArchMap; - -template <> struct ArchMap { - static int const kMin = 50; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 60; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 61; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 70; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 70; - static int const kMax = 75; -}; - -template struct ArchMap { - static int const kMin = 75; - static int const kMax = 1024; -}; - -template struct ArchMap { - static int const kMin = 80; - static int const kMax = 1024; -}; - -template struct ArchMap { - static int const kMin = 86; - static int const kMax = 1024; -}; - -template struct ArchMap { - static int const kMin = 89; - static int const kMax = 100; -}; - -template struct ArchMap { - static int const kMin = 90; - static int const kMax = 1024; -}; - -// Arch conditional WGMMA -template <> struct ArchMap { - static int const kMin = 90; - static int const kMax = 90; -}; - -// Arch conditional sparse WGMMA -template <> struct ArchMap { - static int const kMin = 90; - static int const kMax = 90; -}; - - -template struct ArchMap { - static int const kMin = 100; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 100; - #if (__CUDACC_VER_MAJOR__ >= 13) - static int const kMax = 110; - #else - static int const kMax = 103; - #endif // __CUDACC_VER_MAJOR__ >= 13 -}; - -template struct ArchMap { - static int const kMin = 103; - static int const kMax = 1024; -}; -template <> struct ArchMap { - static int const kMin = 103; - static int const kMax = 103; -}; - -template struct ArchMap { - static int const kMin = 120; - static int const kMax = 121; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h deleted file mode 100644 index 5e80c124e59d24cd90c7c1b0c06bcc3bedfee62f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h +++ /dev/null @@ -1,815 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct MathInstructionDescription { - - /// Shape of the target math instruction - cutlass::gemm::GemmCoord instruction_shape; - - /// Describes the data type of the internal accumulator - NumericTypeID element_accumulator; - - /// Classification of math instruction - OpcodeClassID opcode_class; - - /// Type of math operation performed - MathOperationID math_operation; - - // - // Methods - // - - MathInstructionDescription( - cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), - NumericTypeID element_accumulator = NumericTypeID::kInvalid, - OpcodeClassID opcode_class = OpcodeClassID::kInvalid, - MathOperationID math_operation = MathOperationID::kMultiplyAdd - ): - instruction_shape(instruction_shape), - element_accumulator(element_accumulator), - opcode_class(opcode_class), - math_operation(math_operation) {} - - // Equality operator - inline - bool operator==(MathInstructionDescription const& rhs) const{ - return ( - (instruction_shape == rhs.instruction_shape) && - (element_accumulator == rhs.element_accumulator) && - (opcode_class == rhs.opcode_class) && - (math_operation == rhs.math_operation)); - } - - // Inequality operator - inline - bool operator!=(MathInstructionDescription const& rhs) const { - return !(*this == rhs); - } - -}; - -/// Structure describing the tiled structure of a GEMM-like computation -struct TileDescription { - - /// Describes the shape of a threadblock (in elements) - cutlass::gemm::GemmCoord threadblock_shape; - - /// Describes the number of pipeline stages in the threadblock-scoped mainloop - int threadblock_stages; - - /// Number of warps in each logical dimension - cutlass::gemm::GemmCoord warp_count; - - /// Core math instruction - MathInstructionDescription math_instruction; - - /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. - int minimum_compute_capability; - - /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. - int maximum_compute_capability; - - /// Describes the shape of a cluster (in blocks) - cutlass::gemm::GemmCoord cluster_shape; - - // - // Methods - // - - TileDescription( - cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), - int threadblock_stages = 0, - cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), - MathInstructionDescription math_instruction = MathInstructionDescription(), - int minimum_compute_capability = 0, - int maximum_compute_capability = 0, - cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) - ): - threadblock_shape(threadblock_shape), - threadblock_stages(threadblock_stages), - warp_count(warp_count), - math_instruction(math_instruction), - minimum_compute_capability(minimum_compute_capability), - maximum_compute_capability(maximum_compute_capability), - cluster_shape(cluster_shape) { } - - // Equality operator - inline - bool operator==(TileDescription const& rhs) const{ - return ( - (threadblock_shape == rhs.threadblock_shape) && - (threadblock_stages == rhs.threadblock_stages) && - (warp_count == rhs.warp_count) && - (math_instruction == rhs.math_instruction) && - (minimum_compute_capability == rhs.minimum_compute_capability) && - (maximum_compute_capability == rhs.maximum_compute_capability)); - } - - // Inequality operator - inline - bool operator!=(TileDescription const& rhs) const { - return !(*this == rhs); - } -}; - -/// High-level description of an operation -struct OperationDescription { - - /// Unique identifier describing the operation - char const * name; - - /// Operation provider - Provider provider; - - /// Kind of operation - OperationKind kind; - - /// Describes the tiled structure of a GEMM-like computation - TileDescription tile_description; - - // - // Methods - // - OperationDescription( - char const * name = "unknown", - Provider provider = Provider::kInvalid, - OperationKind kind = OperationKind::kInvalid, - TileDescription const& tile_description = TileDescription() - ): - name(name), provider(provider), kind(kind), tile_description(tile_description) { } -}; - -/// Structure describing the properties of a tensor -struct TensorDescription { - - /// Numeric type of an individual element - NumericTypeID element; - - /// Enumerant identifying the layout function for the tensor - LayoutTypeID layout; - - /// Alignment restriction on pointers, strides, and extents - int alignment; - - /// log2() of the maximum extent of each dimension - int log_extent_range; - - /// log2() of the maximum value each relevant stride may have - int log_stride_range; - - // - // Methods - // - - TensorDescription( - NumericTypeID element = NumericTypeID::kInvalid, - LayoutTypeID layout = LayoutTypeID::kInvalid, - int alignment = 1, - int log_extent_range = 24, - int log_stride_range = 24 - ): - element(element), - layout(layout), - alignment(alignment), - log_extent_range(log_extent_range), - log_stride_range(log_stride_range) { } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all GEMM computations -struct GemmDescription : public OperationDescription { - - /// Indicates the kind of GEMM performed - GemmKind gemm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source matrix - TensorDescription C; - - /// Describes the destination matrix - TensorDescription D; - - /// Describes the sparse meta matrices - TensorDescription E; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - // - // Methods - // - - GemmDescription( - GemmKind gemm_kind = GemmKind::kGemm, - TensorDescription const& A = TensorDescription(), - TensorDescription const& B = TensorDescription(), - TensorDescription const& C = TensorDescription(), - TensorDescription const& D = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} - - GemmDescription( - OperationDescription op_desc, - GemmKind gemm_kind, - TensorDescription const& A, - TensorDescription const& B, - TensorDescription const& C, - TensorDescription const& D, - NumericTypeID element_epilogue, - SplitKMode split_k_mode, - ComplexTransform transform_A, - ComplexTransform transform_B - ): - OperationDescription(op_desc), - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; - -struct BlockScaleDescription { - /// Describes the SFA operand - TensorDescription SFA; - - /// Describes the SFB operand - TensorDescription SFB; - - /// Describes the SFD operand - TensorDescription SFD; - - /// Describes the input ScaleFactor VectorSize - int SFMVecSize; - int SFNVecSize; - int SFKVecSize; - - /// Describes the Output ScaleFactor VectorSize - int EpilogueSFVecSize; - - /// Describes the underlying kind of scaling: - /// Tensor Core supported (BlockScaled) or manual scaling (Blockwise) - OperationKind kind; -}; - -struct GroupedGemmDescription : public OperationDescription { - GemmDescription gemm; - std::optional block_scales; -}; - -/// Description of all GEMM computations -struct BlockScaledGemmDescription : public OperationDescription { - - /// Indicates the kind of GEMM performed - GemmKind gemm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source matrix - TensorDescription C; - - /// Describes the destination matrix - TensorDescription D; - - /// Describes the SFA operand - TensorDescription SFA; - - /// Describes the SFB operand - TensorDescription SFB; - - /// Describes the SFD operand - TensorDescription SFD; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - /// Describes the input ScaleFactor VectorSize - int SFVecSize; - - /// Describes the Output ScaleFactor VectorSize - int EpilogueSFVecSize; - - // - // Methods - // - - BlockScaledGemmDescription( - GemmKind gemm_kind = GemmKind::kGemm, - TensorDescription const& A = TensorDescription(), - TensorDescription const& B = TensorDescription(), - TensorDescription const& C = TensorDescription(), - TensorDescription const& D = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} - - BlockScaledGemmDescription( - OperationDescription op_desc, - GemmKind gemm_kind, - TensorDescription const& A, - TensorDescription const& B, - TensorDescription const& C, - TensorDescription const& D, - NumericTypeID element_epilogue, - SplitKMode split_k_mode, - ComplexTransform transform_A, - ComplexTransform transform_B - ): - OperationDescription(op_desc), - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; - -/// Description of all GEMM computations -struct BlockwiseGemmDescription : public OperationDescription { - - /// Indicates the kind of GEMM performed - GemmKind gemm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source matrix - TensorDescription C; - - /// Describes the destination matrix - TensorDescription D; - - /// Describes the SFA operand - TensorDescription SFA; - - /// Describes the SFB operand - TensorDescription SFB; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - /// Describes the input ScaleFactor VectorSize - int SFMVecSize; - int SFNVecSize; - int SFKVecSize; - - // - // Methods - // - - BlockwiseGemmDescription( - GemmKind gemm_kind = GemmKind::kGemm, - TensorDescription const& A = TensorDescription(), - TensorDescription const& B = TensorDescription(), - TensorDescription const& C = TensorDescription(), - TensorDescription const& D = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} - - BlockwiseGemmDescription( - OperationDescription op_desc, - GemmKind gemm_kind, - TensorDescription const& A, - TensorDescription const& B, - TensorDescription const& C, - TensorDescription const& D, - NumericTypeID element_epilogue, - SplitKMode split_k_mode, - ComplexTransform transform_A, - ComplexTransform transform_B - ): - OperationDescription(op_desc), - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description for structured sparse GEMMs. -struct SparseGemmDescription : public GemmDescription { - - /// Description structure for structured sparse GEMM - SparseGemmDescription( - GemmKind gemm_kind = GemmKind::kGemm, - TensorDescription const& A = TensorDescription(), - TensorDescription const& B = TensorDescription(), - TensorDescription const& C = TensorDescription(), - TensorDescription const& D = TensorDescription(), - TensorDescription const& E = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) - {this->E = E;} -}; - -/// Description of all Reduction operations -struct ReductionDescription : public OperationDescription { - - /// Describes the data type of workspace - NumericTypeID element_workspace; - - /// Describes the data type of final output - NumericTypeID element_output; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; -}; - -/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) -struct RankKDescription : public OperationDescription { - - /// Indicates which device template is used (universal or regular) - RankKKind rank_k_kind; - - /// Number of rank update (rank k or rank 2k) - int num_ranks; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand (used only for SYR2K and HER2K) - TensorDescription B; - - /// Describes the source and destination matrices - TensorDescription C; - - /// Describes the fill mode for matrix C - FillMode fill_mode; - - /// Describes the blas mode (symmetric/hermitian) - BlasMode blas_mode; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - // - // Methods - // - - RankKDescription( - RankKKind rank_k_kind = RankKKind::kUniversal, - int num_ranks = 1, - TensorDescription const& A = TensorDescription(), - TensorDescription const& B = TensorDescription(), - TensorDescription const& C = TensorDescription(), - FillMode fill_mode = FillMode::kInvalid, - BlasMode blas_mode = BlasMode::kInvalid, - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - rank_k_kind(rank_k_kind), - num_ranks(num_ranks), - A(A), - B(B), - C(C), - fill_mode(fill_mode), - blas_mode(blas_mode), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all TRMM computations -struct TrmmDescription : public OperationDescription { - - /// Indicates the kind of TRMM performed - TrmmKind trmm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the side mode for matrix A - SideMode side_mode; - - /// Describes the fill mode for matrix A - FillMode fill_mode; - - /// Describes the diag type for matrix A - DiagType diag_type; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source and destination matrices - TensorDescription D; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - // - // Methods - // - - TrmmDescription( - TrmmKind trmm_kind = TrmmKind::kUniversal, - TensorDescription const& A = TensorDescription(), - SideMode side_mode = SideMode::kInvalid, - FillMode fill_mode = FillMode::kInvalid, - DiagType diag_type = DiagType::kInvalid, - TensorDescription const& B = TensorDescription(), - TensorDescription const& D = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone - ): - trmm_kind(trmm_kind), - A(A), - side_mode(side_mode), - fill_mode(fill_mode), - diag_type(diag_type), - B(B), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all SYMM/HEMM update computations -struct SymmDescription : public OperationDescription { - - /// Indicates which device template is used (universal or regular) - SymmKind symm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source and destination matrices - TensorDescription C; - - /// Describes the side mode for matrix A - SideMode side_mode; - - /// Describes the fill mode for matrix A - FillMode fill_mode; - - /// Describes the blas mode (symmetric/hermitian) - BlasMode blas_mode; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - // - // Methods - // - - SymmDescription( - SymmKind symm_kind = SymmKind::kUniversal, - TensorDescription const& A = TensorDescription(), - TensorDescription const& B = TensorDescription(), - TensorDescription const& C = TensorDescription(), - SideMode side_mode = SideMode::kInvalid, - FillMode fill_mode = FillMode::kInvalid, - BlasMode blas_mode = BlasMode::kInvalid, - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - symm_kind(symm_kind), - A(A), - B(B), - C(C), - side_mode(side_mode), - fill_mode(fill_mode), - blas_mode(blas_mode), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all Conv2d operations -struct ConvDescription : public OperationDescription { - /// Describes the convolution dimension support (2D or 3D) - int conv_dim; - - /// Describes the kind of convolution - ConvKind conv_kind; - - /// Describes the type of iterator algorithm (analytic or precomputed) - IteratorAlgorithmID iterator_algorithm; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the C operand - TensorDescription C; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - // - // Methods - // - // Returns Activation TensorDescription - TensorDescription activation() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return A; - case library::ConvKind::kDgrad : return C; - case library::ConvKind::kWgrad : return B; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Filter TensorDescription - TensorDescription filter() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return B; - case library::ConvKind::kDgrad : return B; - case library::ConvKind::kWgrad : return C; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Output TensorDescription - TensorDescription output() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return C; - case library::ConvKind::kDgrad : return A; - case library::ConvKind::kWgrad : return A; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h deleted file mode 100644 index 027944eb6ac8c6e8f250d83ed33c0899adfbd3e8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h +++ /dev/null @@ -1,365 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief BLAS-like handle used to launch operations on the CUDA device. -*/ - -#pragma once - -#include -#include "cutlass/library/library.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Handle object -class Handle { -private: - - /// Host workspace - static int const kHostWorkspaceSize = (4 << 10); - - /// Provider of operations - Provider provider_; - - /// CUDA device properties - cudaDeviceProp device_; - - /// CUDA stream - cudaStream_t stream_; - - /// Device workspace - void *workspace_; - - /// Size of device workspace in bytes - size_t workspace_size_; - - /// Indicates whether scalars are host or device pointers - ScalarPointerMode scalar_pointer_mode_; - - /// Pointer to the most recently executed operation - Operation const *last_operation_; - - int device_idx_; - -public: - - /// Constructor - Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); - - /// Destructor - ~Handle(); - - /// Move constructor - Handle(Handle && handle); - - /// Move assignment operator - Handle &operator=(Handle && handle); - - // - // Persistent state accessors - // - - /// Returns compute capability of the selected device - int compute_capability() const; - - /// Sets the current CUDA stream - void set_stream(cudaStream_t stream); - - /// Gets the current CUDA stream - cudaStream_t get_stream() const; - - /// Gets the current provider - Provider get_provider() const; - - /// Sets the provider of operations - void set_provider(Provider provider); - - /// Gets the device workspace size - size_t get_workspace_size() const; - - /// Gets a pointer to the device workspace allocation in Global Memory - void *get_workspace() const; - - /// Sets the size of device workspace, invalidating calls to get_device_workspace() - void set_workspace_size(size_t bytes); - - /// Gets the scalar pointer mode - ScalarPointerMode get_scalar_pointer_mode() const; - - /// Sets the scalar pointer mode - void set_scalar_pointer_mode(ScalarPointerMode mode); - - /// Gets the most recently executed operation - Operation const *get_last_operation() const; - - // - // Computations - // - - /// Executes a GEMM computation: D <= alpha * A*B + beta * C - Status gemm( - - int M, /// GEMM M dimension - int N, /// GEMM N dimension - int K, /// GEMM K dimension - - NumericTypeID element_compute, /// Data type of internal accumulation - - NumericTypeID element_scalar, /// Data type of alpha/beta scalars - - void const *alpha, /// Pointer to alpha scalar - - NumericTypeID element_A, /// Data type of A matrix elements - LayoutTypeID layout_A, /// Layout of A matrix - ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices - - void const * ptr_A, /// Pointer to A matrix in Global Memory - int64_t lda, /// Leading dimension of A matrix - - NumericTypeID element_B, /// Data type of B matrix elements - LayoutTypeID layout_B, /// Layout of B matrix - ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices - - void const * ptr_B, /// Pointer to B matrix in Global Memory - int64_t ldb, /// Leading dimension of B matrix - - void const * beta, /// Pointer to beta scalar - - NumericTypeID element_C, /// Data type of C and D matrices - - void const * ptr_C, /// Pointer to C matrix - int64_t ldc, /// Leading dimension of C matrix - - void * ptr_D, /// Pointer to D matrix - int64_t ldd /// Leading dimension of D matrix - ); - - /// Executes a GEMM computation: D <= alpha * A*B + beta * C. - // - // Supports batched-strided, batched array or split-K serial or split-K parallel. - // - Status gemm_universal( - - GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched - - int M, /// GEMM M dimension - int N, /// GEMM N dimension - int K, /// GEMM K dimension - - int cluster_m, /// cluster shape M dimension - int cluster_n, /// cluster shape N dimension - int cluster_k, /// cluster shape K dimension - int cluster_m_fallback, /// Fallback cluster shape M dimension - int cluster_n_fallback, /// Fallback cluster shape N dimension - int cluster_k_fallback, /// Fallback cluster shape K dimension - - - NumericTypeID element_compute, /// Data type of internal accumulation - - NumericTypeID element_scalar, /// Data type of alpha/beta scalars - - void const *alpha, /// Pointer to alpha scalar - - NumericTypeID element_A, /// Data type of A matrix elements - LayoutTypeID layout_A, /// Layout of A matrix - ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices - void const * ptr_A, /// Pointer to A matrix in Global Memory - int64_t lda, /// Leading dimension of A matrix - - NumericTypeID element_B, /// Data type of B matrix elements - LayoutTypeID layout_B, /// Layout of B matrix - ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices - void const * ptr_B, /// Pointer to B matrix in Global Memory - int64_t ldb, /// Leading dimension of B matrix - - void const * beta, /// Pointer to beta scalar - - NumericTypeID element_C, /// Data type of C matrix - LayoutTypeID layout_C, /// Layout of D matrix - void const * ptr_C, /// Pointer to C matrix - int64_t ldc, /// Leading dimension of C matrix - - NumericTypeID element_D, /// Data type of D matrix - LayoutTypeID layout_D, /// Layout of D matrix - void * ptr_D, /// Pointer to D matrix - int64_t ldd, /// Leading dimension of D matrix - - int batch_count = 1, /// Batch count or number of split-K slices - - int64_t batch_stride_A = 0, /// Batch stride of A operand - int64_t batch_stride_B = 0, /// Batch stride of B operand - int64_t batch_stride_C = 0, /// Batch stride of C operand - int64_t batch_stride_D = 0 /// Batch stride of D operand - ); - - /// Planar complex GEMM - /// - /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. - /// - Status gemm_planar_complex( - - int M, /// GEMM M dimension - int N, /// GEMM N dimension - int K, /// GEMM K dimension - - NumericTypeID element_compute, /// Data type of internal accumulation - - NumericTypeID element_scalar, /// Data type of alpha/beta scalars - - void const *alpha, /// Pointer to alpha scalar - - NumericTypeID element_A, /// Data type of A matrix elements - LayoutTypeID layout_A, /// Layout of A matrix - ComplexTransform transform_A, /// Complex transformation applied to A matrix - - void const * ptr_A_real, /// Pointer to real part of A matrix - void const * ptr_A_imag, /// Pointer to imaginary part of A matrix - int64_t lda_real, /// Leading dimension of real part of A matrix - int64_t lda_imag, /// Leading dimension of imaginary part of A matrix - - NumericTypeID element_B, /// Data type of B matrix elements - LayoutTypeID layout_B, /// Layout of B matrix - ComplexTransform transform_B, /// Complex transformation applied to B matrix - - void const * ptr_B_real, /// Pointer to real part of B matrix - void const * ptr_B_imag, /// Pointer to imaginary part of B matrix - int64_t ldb_real, /// Leading dimension of real part of B matrix - int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix - - void const * beta, /// Pointer to beta scalar - - NumericTypeID element_C, /// Data type of C and D matrix - - void const * ptr_C_real, /// Pointer to real part of C matrix - void const * ptr_C_imag, /// Pointer to imaginary part of C matrix - int64_t ldc_real, /// Leading dimension of real part of C matrix - int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix - - void * ptr_D_real, /// Pointer to real part of D matrix - void * ptr_D_imag, /// Pointer to imaginary part of D matrix - int64_t ldd_real, /// Leading dimension of real part of D matrix - int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix - - int batch_count = 1, /// Number of batched GEMMs to execute - - int64_t batch_stride_A_real = 0, - int64_t batch_stride_A_imag = 0, - - int64_t batch_stride_B_real = 0, - int64_t batch_stride_B_imag = 0, - - int64_t batch_stride_C_real = 0, - int64_t batch_stride_C_imag = 0, - - int64_t batch_stride_D_real = 0, - int64_t batch_stride_D_imag = 0 - ); - - /// Planar complex GEMM loading pointers from arrays in global memory - Status gemm_planar_complex_array( - - int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) - int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) - int expected_K, /// Expected GEMM K dimension - int batch_count, /// Number of independent GEMM computations to execute - - int const *M, /// Array containing the GEMM M dimension for each batch index - int const *N, /// Array containing the GEMM N dimension for each batch index - int const *K, /// Array containing the GEMM K dimension for each batch index - - NumericTypeID element_compute, /// Data type of internal accumulation - - NumericTypeID element_scalar, /// Data type of alpha/beta scalars - - void const *alpha, /// Pointer to alpha scalar - - NumericTypeID element_A, /// Data type of A matrix elements - LayoutTypeID layout_A, /// Layout of A matrix - ComplexTransform transform_A, /// Complex transformation applied to A matrix - - void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices - void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices - - int64_t lda_real, /// Leading dimension of real part of A matrix - int64_t lda_imag, /// Leading dimension of imaginary part of A matrix - - NumericTypeID element_B, /// Data type of B matrix elements - LayoutTypeID layout_B, /// Layout of B matrix - ComplexTransform transform_B, /// Complex transformation applied to B matrix - - void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices - void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices - - int64_t ldb_real, /// Leading dimension of real part of B matrix - int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix - - void const * beta, /// Pointer to beta scalar - - NumericTypeID element_C, /// Data type of C and D matrix - - void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices - void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices - - int64_t ldc_real, /// Leading dimension of real part of C matrix - int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix - - void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices - void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices - - int64_t ldd_real, /// Leading dimension of real part of D matrix - int64_t ldd_imag /// Leading dimension of imaginary part of D matrix - ); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Unique pointer storing the handle -using HandlePtr = std::unique_ptr; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace -Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace -Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation); -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h deleted file mode 100644 index 6764d9a6d81286c8bba0f5184b17819bfae86978..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h +++ /dev/null @@ -1,995 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - - \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. - - Generally, - - description - compile-time constant parameters used to instantiate an operation - - configuration - runtime parameters with computationally expensive initialization - - arguments - runtime parameters that may be passed to an initialized operation with low - computational overhead -*/ - -#ifndef CUTLASS_LIBRARY_LIBRARY_H -#define CUTLASS_LIBRARY_LIBRARY_H - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include -#include -#include -#include -#include - -#include "cutlass/cutlass.h" -#include "cutlass/library/types.h" -#include "cutlass/library/descriptions.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/blas3.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Mode of Universal GEMM -using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Base class for all operations -class Operation { -public: - - virtual ~Operation() { } - - virtual OperationDescription const & description() const = 0; - - virtual Status can_implement( - void const *configuration, - void const *arguments) const = 0; - - virtual uint64_t get_host_workspace_size( - void const *configuration) const = 0; - - virtual uint64_t get_device_workspace_size( - void const *configuration, - void const *arguments = nullptr) const = 0; - - virtual Status initialize( - void const *configuration, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const = 0; - - // Originally designed for metadata, but should be useful for FP8/6/4 too. - virtual Status initialize_with_profiler_workspace( - void const *configuration, - void *host_workspace, - void *device_workspace, - uint8_t **profiler_workspace_ptrs, - int problem_count, - cudaStream_t stream = nullptr) { - return Status::kErrorNotSupported; - } - - virtual Status run( - void const *arguments, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const = 0; - - // Set arguments that should only be set once before verifying or profiling the kernel. - // This should encompass any expensive operations that don't vary from run to run - // (e.g., max_active_clusters). - virtual Status initialize_with_arguments(void* arguments_ptr) const { - return Status::kSuccess; - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for basic GEMM operations -// -// OperationKind: Gemm -// GemmKind: Gemm -// -struct GemmConfiguration { - - /// GEMM problem size - gemm::GemmCoord problem_size{}; - - /// Leading dimension of A matrix - int64_t lda{0}; - - /// Leading dimension of B matrix - int64_t ldb{0}; - - /// Leading dimension of C matrix - int64_t ldc{0}; - - /// Leading dimension of D matrix - int64_t ldd{0}; - - /// Number of partitions of K dimension - int split_k_slices{0}; -}; - -/// Arguments for GEMM -struct GemmArguments { - - /// Pointer to A matrix - void const *A{nullptr}; - - /// Pointer to B matrix - void const *B{nullptr}; - - /// Pointer to C matrix - void const *C{nullptr}; - - /// Pointer to D matrix - void *D{nullptr}; - - /// Host or device pointer to alpha scalar - void const *alpha{nullptr}; - - /// Host or device pointer to beta scalar - void const *beta{nullptr}; - - /// Enumerant indicating whether alpha/beta point to host or device memory - ScalarPointerMode pointer_mode{}; - - /// Whether to use PDL when launching the kernel - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for batched GEMM in which multiple matrix products are computed -// -// OperationKind: Gemm -// GemmKind: Batched - -struct GemmBatchedConfiguration { - - /// GEMM problem size - gemm::GemmCoord problem_size{}; - - /// Leading dimension of A matrix - int64_t lda{0}; - - /// Leading dimension of B matrix - int64_t ldb{0}; - - /// Leading dimension of C matrix - int64_t ldc{0}; - - /// Leading dimension of D matrix - int64_t ldd{0}; - - /// Stride between instances of the A matrix in memory - int64_t batch_stride_A{0}; - - /// Stride between instances of the B matrix in memory - int64_t batch_stride_B{0}; - - /// Stride between instances of the C matrix in memory - int64_t batch_stride_C{0}; - - /// Stride between instances of the D matrix in memory - int64_t batch_stride_D{0}; - - /// Number of GEMMs in batch - int batch_count{1}; -}; - -/// Arguments to batched GEMM -using GemmBatchedArguments = GemmArguments; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for batched GEMM in which multiple matrix products are computed -// -// OperationKind: Gemm -// GemmKind: Array - -struct GemmArrayConfiguration { - - gemm::GemmCoord problem_size{}; - - /// Leading dimension of A matrix - int64_t lda{0}; - - /// Leading dimension of B matrix - int64_t ldb{0}; - - /// Leading dimension of C matrix - int64_t ldc{0}; - - /// Leading dimension of D matrix - int64_t ldd{0}; - - int batch_count{1}; -}; - -/// Arguments for GEMM - used by all the GEMM operations -struct GemmArrayArguments { - void const * const *A{nullptr}; - void const * const *B{nullptr}; - void const * const *C{nullptr}; - void * const *D{nullptr}; - void const *alpha{nullptr}; - void const *beta{nullptr}; - ScalarPointerMode pointer_mode{}; - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex -// -// OperationKind: Gemm -// GemmKind: Universal - -struct GemmUniversalConfiguration { - - GemmUniversalMode mode{GemmUniversalMode::kGemm}; - gemm::GemmCoord problem_size{}; - gemm::GemmCoord cluster_shape{}; - gemm::GemmCoord cluster_shape_fallback{}; - int batch_count{1}; - - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - int64_t ldd{0}; - - int device_count{1}; -}; - -enum class Sm90MixedInputWiderOperand { - A = 0, - B = 1 -}; - -struct GemmUniversalArguments { - // NOTE: these are replicated for 3.0 interfaces - gemm::GemmCoord problem_size{}; - gemm::GemmCoord cluster_shape{}; - gemm::GemmCoord cluster_shape_fallback{}; - int batch_count{1}; - - void const *A{nullptr}; - void const *B{nullptr}; - void const *C{nullptr}; - void *D{nullptr}; - - void const *alpha{nullptr}; - void const *beta{nullptr}; - ScalarPointerMode pointer_mode{}; - - // NOTE: these are replicated for 3.0 interfaces - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - int64_t ldd{0}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_C{0}; - int64_t batch_stride_D{0}; - - // Needed for some 3.x kernels - int sm_count{0}; - library::RasterOrder raster_order{}; - library::RuntimeDatatype runtime_input_datatype_a{}; - library::RuntimeDatatype runtime_input_datatype_b{}; - int swizzle_size{1}; - int split_k_slices{1}; - - // For SM90 mixed input dtype kernels - bool is_sm90_mixed_dtype{false}; - Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B}; - bool generate_scale_and_zero{false}; - bool generate_dequantized_AB{false}; - void *Scale{nullptr}; // Scale tensor - void *Zero{nullptr}; // Zero tensor - void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification - void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle - void *packed_Scale{nullptr}; // Packed scale for int4 * fp8 - - int device_index{0}; - - bool use_pdl{false}; -}; - -/// Block Scaled GEMM -// -// OperationKind: kBlockScaledGemm -// GemmKind: Universal - -struct BlockScaledGemmArguments { - // NOTE: these are replicated for 3.0 interfaces - gemm::GemmCoord problem_size{}; - gemm::GemmCoord cluster_shape{}; - gemm::GemmCoord cluster_shape_fallback{}; - int batch_count{1}; - - void const *A{nullptr}; - void const *B{nullptr}; - void const *SFA{nullptr}; - void const *SFB{nullptr}; - void const *C{nullptr}; - void *D{nullptr}; - void *SFD{nullptr}; - - void const *alpha{nullptr}; - void const *beta{nullptr}; - ScalarPointerMode pointer_mode{}; - - // NOTE: these are replicated for 3.0 interfaces - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - int64_t ldd{0}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_C{0}; - int64_t batch_stride_D{0}; - - // Needed for ScaleFactor Generation - void const *norm_constant{nullptr}; - - // Needed for some 3.x kernels - int sm_count{0}; - library::RasterOrder raster_order{}; - int swizzle_size{1}; - int split_k_slices{1}; - - library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; - library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; - - bool use_pdl{false}; -}; - -/// Blockwise GEMM -// -// OperationKind: kBlockwiseGemm -// GemmKind: Universal - -struct BlockwiseGemmArguments { - // NOTE: these are replicated for 3.0 interfaces - gemm::GemmCoord problem_size{}; - gemm::GemmCoord cluster_shape{}; - gemm::GemmCoord cluster_shape_fallback{}; - int batch_count{1}; - - void const *A{nullptr}; - void const *B{nullptr}; - void const *SFA{nullptr}; - void const *SFB{nullptr}; - void const *C{nullptr}; - void *D{nullptr}; - - void const *alpha{nullptr}; - void const *beta{nullptr}; - ScalarPointerMode pointer_mode{}; - - // NOTE: these are replicated for 3.0 interfaces - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - int64_t ldd{0}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_C{0}; - int64_t batch_stride_D{0}; - - int sf_m_vec_size{0}; - int sf_n_vec_size{0}; - int sf_k_vec_size{0}; - - // Needed for some 3.x kernels - int sm_count{0}; - library::RasterOrder raster_order{}; - int swizzle_size{1}; - int split_k_slices{1}; - - library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; - library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; - - bool use_pdl{false}; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Complex valued GEMM in which real and imaginary parts are separated by a stride -// -// OperationKind: Gemm -// GemmKind: Planar complex - -struct GemmPlanarComplexConfiguration { - - GemmUniversalMode mode{GemmUniversalMode::kGemm}; - gemm::GemmCoord problem_size{}; - int batch_count{1}; - int64_t lda_real{0}; - int64_t lda_imag{0}; - int64_t ldb_real{0}; - int64_t ldb_imag{0}; - int64_t ldc_real{0}; - int64_t ldc_imag{0}; - int64_t ldd_real{0}; - int64_t ldd_imag{0}; -}; - -/// Arguments for planar complex GEMMs -struct GemmPlanarComplexArguments { - - void const *A_real{nullptr}; - void const *A_imag{nullptr}; - void const *B_real{nullptr}; - void const *B_imag{nullptr}; - void const *C_real{nullptr}; - void const *C_imag{nullptr}; - void *D_real{nullptr}; - void *D_imag{nullptr}; - void const *alpha{nullptr}; - void const *beta{nullptr}; - ScalarPointerMode pointer_mode{}; - - int64_t batch_stride_A_real{0}; - int64_t batch_stride_A_imag{0}; - int64_t batch_stride_B_real{0}; - int64_t batch_stride_B_imag{0}; - int64_t batch_stride_C_real{0}; - int64_t batch_stride_C_imag{0}; - int64_t batch_stride_D_real{0}; - int64_t batch_stride_D_imag{0}; - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This is a special form of planar complex which loads pointers and problem size -/// from memory. -struct GemmPlanarComplexArrayConfiguration { - - gemm::GemmCoord problem_size{}; - int batch_count{1}; - - int64_t lda_real{0}; - int64_t lda_imag{0}; - int64_t ldb_real{0}; - int64_t ldb_imag{0}; - int64_t ldc_real{0}; - int64_t ldc_imag{0}; - int64_t ldd_real{0}; - int64_t ldd_imag{0}; -}; - -/// Arguments for planar complex GEMMs -struct GemmPlanarComplexArrayArguments { - - int const *M{nullptr}; - int const *N{nullptr}; - int const *K{nullptr}; - - void const * const * A_real{nullptr}; - void const * const * A_imag{nullptr}; - void const * const * B_real{nullptr}; - void const * const * B_imag{nullptr}; - void const * const * C_real{nullptr}; - void const * const * C_imag{nullptr}; - void * const * D_real{nullptr}; - void * const * D_imag{nullptr}; - - void const * alpha{nullptr}; - void const * beta{nullptr}; - ScalarPointerMode pointer_mode{}; - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Grouped GEMM supporting -// -// OperationKind: Gemm -// GemmKind: Grouped - -struct GemmGroupedConfiguration { - int problem_count{0}; - // GemmGroupedConfiguration is passed to initialize(), which - // is responsible for allocating the device-side stride storage. - int64_t* lda; - int64_t* ldb; - int64_t* ldc; - - cute::Shape* problem_sizes_3x_host; -}; - -struct GemmGroupedArguments { - int problem_count{}; - gemm::GemmCoord* problem_sizes{nullptr}; - - void* ptr_A{nullptr}; - void* ptr_B{nullptr}; - void* ptr_C{nullptr}; - void* ptr_D{nullptr}; - - int64_t* lda{nullptr}; - int64_t* ldb{nullptr}; - int64_t* ldc{nullptr}; - int64_t* ldd{nullptr}; - - void const *alpha{nullptr}; - void const *beta{nullptr}; - ScalarPointerMode pointer_mode{}; - bool use_pdl{false}; - - gemm::GemmCoord cluster_shape{}; - gemm::GemmCoord cluster_shape_fallback{}; - - library::RasterOrder raster_order{}; - library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; - library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; - int swizzle_size{1}; - - // these should really be in the configuration but staying consistent with GEMM - int sm_count{0}; - int max_active_clusters{0}; - - // The user is responsible for allocating storage for problem sizes. - // Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we - // unfortunately need to have both options in this struct, and the - // underlying operation uses the one it needs. - cute::Shape* problem_sizes_3x; - cute::Shape* problem_sizes_3x_host; -}; - -struct GroupedGemmBlockScaledArguments : GemmGroupedArguments { - void* SFA{nullptr}; - void* SFB{nullptr}; - void* SFD{nullptr}; - void* norm_constant{nullptr}; -}; - -struct GroupedGemmBlockwiseArguments : GemmGroupedArguments { - void* SFA{nullptr}; - void* SFB{nullptr}; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// OperationKind: kSparseGemm -// - -/// Computes GEMM assuming one of the inputs has 2:4 structured sparsity. -struct SparseGemmConfiguration { - - GemmUniversalMode mode{GemmUniversalMode::kGemm}; - gemm::GemmCoord problem_size{}; - int batch_count{1}; /// number of sparse matrix products in batch - int64_t lda{0}; /// leading dimension of A operand - int64_t ldb{0}; /// leading dimension of B operand - int64_t ldc{0}; /// leading dimension of C operand - int64_t ldd{0}; /// leading dimension of D operand - int64_t lde{0}; /// leading dimension of E operand (metadata matrix) - int64_t batch_stride_A{0}; // stride between matrices - int64_t batch_stride_B{0}; // stride between matrices - int64_t batch_stride_C{0}; // stride between matrices - int64_t batch_stride_D{0}; // stride between matrices - int64_t batch_stride_E{0}; // stride between matrices -}; - -/// Arguments for sparse GEMMs -struct SparseGemmArguments { - void const *A{nullptr}; /// pointer to A matrix - void const *B{nullptr}; /// pointer to B matrix - void const *C{nullptr}; /// pointer to C matrix - void *D{nullptr}; /// pointer to D matrix - void const *E{nullptr}; /// pointer to E matrix (metadata) - void const *alpha{nullptr}; /// pointer to alpha scalar - void const *beta{nullptr}; /// pointer to beta scalar - ScalarPointerMode pointer_mode{}; /// enumerant indicating whether alpha/beta pointers are host - /// or device pointers. - bool use_pdl{false}; /// Whether to use PDL when launching the kernel -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for basic Rank K update operations -// -// OperationKind: (Syrk, Herk, Syr2k, Her2k) -// RankKKind: Universal -// -struct RankKConfiguration { - - /// SYRK problem size - gemm::GemmCoord problem_size{}; - - /// Leading dimension of A matrix - int64_t lda{0}; - - /// Leading dimension of B matrix - int64_t ldb{0}; - - /// Leading dimension of C matrix - int64_t ldc{0}; - - /// Leading dimension of D matrix - int64_t ldd{0}; - - /// Batch Count - int batch_count{1}; -}; - -/// Arguments for (Syrk, Herk, Syr2k, Her2k) -struct RankKArguments { - - /// Pointer to A matrix - void const *A{nullptr}; - - /// Pointer to B matrix (used only for Syr2k and Her2k) - void const *B{nullptr}; - - /// Pointer to C matrix - void const *C{nullptr}; - - /// Pointer to D matrix - void *D{nullptr}; - - /// Host or device pointer to alpha scalar - void const *alpha{nullptr}; - - /// Host or device pointer to beta scalar - void const *beta{nullptr}; - - /// Enumerant indicating whether alpha/beta point to host or device memory - ScalarPointerMode pointer_mode{}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_C{0}; - int64_t batch_stride_D{0}; - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for basic TRMM operations -// -// OperationKind: Trmm -// TrmmKind: Universal -// -struct TrmmConfiguration { - - /// TRMM problem size - gemm::GemmCoord problem_size{}; - - /// Leading dimension of A matrix - int64_t lda{0}; - - /// Leading dimension of B matrix - int64_t ldb{0}; - - /// Leading dimension of D matrix - int64_t ldd{0}; - - /// Batch Count - int batch_count{1}; -}; - -/// Arguments for TRMM -struct TrmmArguments { - - /// Pointer to A matrix - void const *A{nullptr}; - - /// Pointer to B matrix - void const *B{nullptr}; - - /// Pointer to D matrix - void *D{nullptr}; - - /// Host or device pointer to alpha scalar - void const *alpha{nullptr}; - - /// Host or device pointer to beta scalar - void const *beta{nullptr}; - - /// Enumerant indicating whether alpha/beta point to host or device memory - ScalarPointerMode pointer_mode{}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_D{0}; - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for basic SYMM/HEMM update operations -// -// OperationKind: (Symm, Hemm) -// SymmKind: Universal -// -struct SymmConfiguration { - - /// SYMM/HEMM problem size - gemm::GemmCoord problem_size{}; - - /// Leading dimension of A matrix - int64_t lda{0}; - - /// Leading dimension of B matrix - int64_t ldb{0}; - - /// Leading dimension of C matrix - int64_t ldc{0}; - - /// Leading dimension of D matrix - int64_t ldd{0}; - - /// Batch Count - int batch_count{1}; -}; - -/// Arguments for (Symm, Hemm) -struct SymmArguments { - - /// Pointer to A matrix - void const *A{nullptr}; - - /// Pointer to B matrix - void const *B{nullptr}; - - /// Pointer to C matrix - void const *C{nullptr}; - - /// Pointer to D matrix - void *D{nullptr}; - - /// Host or device pointer to alpha scalar - void const *alpha{nullptr}; - - /// Host or device pointer to beta scalar - void const *beta{nullptr}; - - /// Enumerant indicating whether alpha/beta point to host or device memory - ScalarPointerMode pointer_mode{}; - - int64_t batch_stride_A{0}; - int64_t batch_stride_B{0}; - int64_t batch_stride_C{0}; - int64_t batch_stride_D{0}; - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Two dimensional convolution -// -// OperationKind: Conv2d -// -struct Conv2dConfiguration { - - conv::SplitKMode split_k_mode; - - /// Conv2d problem size - // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) - // also includes (split_k_slices, groups) - conv::Conv2dProblemSize problem_size{}; - - // stride of operand A - std::vector stride_a{}; - - // stride of operand B - std::vector stride_b{}; - - // stride of operand C - std::vector stride_c{}; -}; - - -/// Three dimensional convolution -// -// OperationKind: Conv3d -// -struct Conv3dConfiguration { - - conv::SplitKMode split_k_mode{}; - - /// Conv2d problem size - // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) - // also includes (split_k_slices, groups) - conv::Conv3dProblemSize problem_size{}; - - /// Layout object for activations tensor - layout::TensorNDHWC layout_activations{}; - - /// Layout object for filters tensor - layout::TensorNDHWC layout_filters{}; - - /// Layout object for source tensor - layout::TensorNDHWC layout_source{}; - - /// Layout object for output tensor - layout::TensorNDHWC layout_output{}; - - // - // Methods - // - - // Mapping functions (A,B,C -> activation,filter,output) - layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { - switch (conv_kind) { - case library::ConvKind::kFprop: return layout_activations; - case library::ConvKind::kDgrad: return layout_output; - case library::ConvKind::kWgrad: return layout_output; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { - switch (conv_kind) { - case library::ConvKind::kFprop: return layout_filters; - case library::ConvKind::kDgrad: return layout_filters; - case library::ConvKind::kWgrad: return layout_activations; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - layout::TensorNDHWC layout_c(library::ConvKind const &conv_kind) const { - switch (conv_kind) { - case library::ConvKind::kFprop: return layout_output; - case library::ConvKind::kDgrad: return layout_activations; - case library::ConvKind::kWgrad: return layout_filters; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } -}; - -/// Arguments for CONV -struct ConvArguments { - - ///////////////////////////////////////////////////////// - /// ImplicitGemm matrices A, B, C, D - ///////////////////////////////////////////////////////// - /// pointer to implicit gemm matrix A - void const *A{nullptr}; - - /// pointer to implicit gemm matrix B - void const *B{nullptr}; - - /// pointer to reordered matrix B - void const *reordered_B{nullptr}; - - /// pointer to implicit gemm matrix C - void const *C{nullptr}; - - /// pointer to implicit gemm destination matrix D - void *D{nullptr}; - - /// Host or device pointer to alpha scalar - void const *alpha{nullptr}; - - /// Host or device pointer to beta scalar - void const *beta{nullptr}; - - /// Enumerant indicating whether alpha/beta point to host or device memory - ScalarPointerMode pointer_mode{}; - - /// Whether to use PDL when launching the kernel - bool use_pdl{false}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Configuration for Reduction operations -// -// OperationKind: Reduction -// -struct ReductionConfiguration { - - /// Reduction problem size - MatrixCoord problem_size{}; - - /// Number of partitions to reduce - int partitions{0}; - - /// Number of elements between each partition - int64_t partition_stride{0}; - - /// leading dimension of 'w'orkspace operand - int64_t ldw{0}; - - /// leading dimension of 's'ource operand - int64_t lds{0}; - - /// leading dimension of 'd'estination operand - int64_t ldd{0}; -}; - -/// Arguments for Reduction -struct ReductionArguments { - - /// Pointer to workspace matrix - void const *workspace{nullptr}; - - /// Pointer to source matrix - void const *source{nullptr}; - - /// Pointer to destination matrix - void *destination{nullptr}; - - /// pointer to reference matrix - void *reference{nullptr}; - - /// Host or device pointer to alpha scalar - void const *alpha{nullptr}; - - /// Host or device pointer to beta scalar - void const *beta{nullptr}; - - /// Enumerant indicating whether alpha/beta point to host or device memory - ScalarPointerMode pointer_mode{}; - - /// Whether to use PDL when launching the kernel - bool use_pdl{false}; -}; - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#endif diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h deleted file mode 100644 index c4fb0ee8ca32124450b1063cc3613078e600479d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h +++ /dev/null @@ -1,114 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Manifest of CUTLASS Library - - This is the root of the data structure containing CUTLASS objects -*/ - -#pragma once - -#include -#include -#include - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "library.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Forward declaration -class Manifest; - -// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) -void initialize_all(Manifest &manifest); - -// init and insert all reduction op in manifest object (manually instantiated in library/reduction) -void initialize_all_reduction_op(Manifest &manifest); - -///////////////////////////////////////////////////////////////////////////////////////////////////////// - -/// List of operations -using OperationVector = std::vector>; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Manifest of CUTLASS Library -class Manifest { -private: - - /// Operation provider - Provider provider_; - - /// Global list of operations - OperationVector operations_; - -public: - Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } - - /// Top-level initialization - Status initialize(); - - /// Used for initialization - void reserve(size_t operation_count); - - /// Graceful shutdown - Status release(); - - /// Appends an operation and takes ownership - void append(Operation *operation_ptr) {\ - // This function is inline s.t. it is present in generated libraries - // without having to compile or link in manifest.cpp - operations_.emplace_back(operation_ptr); - } - - /// Returns an iterator to the first operation - OperationVector const &operations() const; - - /// Returns a const iterator - OperationVector::const_iterator begin() const; - - /// Returns a const iterator - OperationVector::const_iterator end() const; -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h deleted file mode 100644 index f36232c8dc833e2b24d681686f6662e79b7ecd0a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h +++ /dev/null @@ -1,905 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - \file - \brief Defines a data structure in which a set of functionally equivalent library::Operation - instances may be queried. -*/ - -#pragma once -#include -#include -#include -#include - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/util.h" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Data Structures for Gemm Functional Maps -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tuple uniquely identifying Gemm functional behavior -struct GemmFunctionalKey { - - Provider provider; - GemmKind gemm_kind; - NumericTypeID element_compute; - NumericTypeID element_scalar; - NumericTypeID element_A; - LayoutTypeID layout_A; - ComplexTransform transform_A; - NumericTypeID element_B; - LayoutTypeID layout_B; - ComplexTransform transform_B; - NumericTypeID element_C; - LayoutTypeID layout_C; - NumericTypeID element_D; - LayoutTypeID layout_D; - - // - // Methods - // - - inline - GemmFunctionalKey( - Provider provider, - GemmKind gemm_kind = GemmKind::kGemm, - NumericTypeID element_compute = NumericTypeID::kF32, - NumericTypeID element_scalar = NumericTypeID::kF32, - NumericTypeID element_A = NumericTypeID::kF16, - LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, - ComplexTransform transform_A = ComplexTransform::kNone, - NumericTypeID element_B = NumericTypeID::kF16, - LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, - ComplexTransform transform_B = ComplexTransform::kNone, - NumericTypeID element_C = NumericTypeID::kF16, - LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, - NumericTypeID element_D = NumericTypeID::kF16, - LayoutTypeID layout_D = LayoutTypeID::kColumnMajor - ): - provider(provider), - gemm_kind(gemm_kind), - element_compute(element_compute), - element_scalar(element_scalar), - element_A(element_A), - layout_A(layout_A), - transform_A(transform_A), - element_B(element_B), - layout_B(layout_B), - transform_B(transform_B), - element_C(element_C), - layout_C(layout_C), - element_D(element_D), - layout_D(layout_D) - { } - - inline - bool operator==(GemmFunctionalKey const &rhs) const { - return - (provider == rhs.provider) && - (gemm_kind == rhs.gemm_kind) && - (element_compute == rhs.element_compute) && - (element_scalar == rhs.element_scalar) && - (element_A == rhs.element_A) && - (layout_A == rhs.layout_A) && - (transform_A == rhs.transform_A) && - (element_B == rhs.element_B) && - (layout_B == rhs.layout_B) && - (transform_B == rhs.transform_B) && - (element_C == rhs.element_C) && - (layout_C == rhs.layout_C) && - (element_D == rhs.element_D) && - (layout_D == rhs.layout_D); - } - - inline - bool operator!=(GemmFunctionalKey const &rhs) const { - return !(*this == rhs); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -inline -std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { - - out << "{\n" - << " provider: " << to_string(k.provider) << "\n" - << " gemm_kind: " << to_string(k.gemm_kind) << "\n" - << " element_compute: " << to_string(k.element_compute) << "\n" - << " element_scalar: " << to_string(k.element_scalar) << "\n" - << " element_A: " << to_string(k.element_A) << "\n" - << " layout_A: " << to_string(k.layout_A) << "\n" - << " transform_A: " << to_string(k.transform_A) << "\n" - << " element_B: " << to_string(k.element_B) << "\n" - << " layout_B: " << to_string(k.layout_B) << "\n" - << " transform_B: " << to_string(k.transform_B) << "\n" - << " element_C: " << to_string(k.element_C) << "\n" - << " layout_C: " << to_string(k.layout_C) << "\n" - << " element_D: " << to_string(k.element_D) << "\n" - << " layout_D: " << to_string(k.layout_D) << "\n" - << "}"; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Hash function for GemmFunctionalKey -struct GemmFunctionalKeyHasher { - using IntHash = std::hash; - - inline - static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); - } - - inline - size_t operator()(GemmFunctionalKey const &key) const { - IntHash hash; - - return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.gemm_kind)), 2) ^ - rotl(hash(int(key.element_compute)), 3) ^ - rotl(hash(int(key.element_scalar)), 4) ^ - rotl(hash(int(key.element_A)), 5) ^ - rotl(hash(int(key.layout_A)), 6) ^ - rotl(hash(int(key.transform_A)), 7) ^ - rotl(hash(int(key.element_B)), 8) ^ - rotl(hash(int(key.layout_B)), 9) ^ - rotl(hash(int(key.transform_B)), 10) ^ - rotl(hash(int(key.element_C)), 11) ^ - rotl(hash(int(key.layout_C)), 12) ^ - rotl(hash(int(key.element_D)), 13) ^ - rotl(hash(int(key.layout_D)), 14); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Establishes a partial ordering to search for GEMM operators -struct GemmPreferenceKey { - - int compute_capability; - int alignment; - - // - // Methods - // - - GemmPreferenceKey(): compute_capability(), alignment() { } - - GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } - - bool operator<(GemmPreferenceKey const &rhs) const { - return (compute_capability < rhs.compute_capability) || - ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); - } - - bool operator==(GemmPreferenceKey const &rhs) const { - return compute_capability == rhs.compute_capability; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -inline -std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { - out << "{\n" - << "compute_capability : " << key.compute_capability << std::endl - << "alignment : " << key.alignment << std::endl - << "}"; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Maps minimum compute capability onto a vector of possible operations -using GemmOperationVectorMap = std::map< - GemmPreferenceKey, - std::vector ->; - -/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm -using GemmOperationFunctionalMap = std::unordered_map< - GemmFunctionalKey, - GemmOperationVectorMap, - GemmFunctionalKeyHasher ->; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Data Structures for BlockScaled Gemm Functional Maps -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tuple uniquely identifying Gemm functional behavior -struct BlockScaledGemmFunctionalKey { - - Provider provider; - GemmKind gemm_kind; - OperationKind kind; - NumericTypeID element_compute; - NumericTypeID element_scalar; - NumericTypeID element_A; - LayoutTypeID layout_A; - NumericTypeID element_SFA; - NumericTypeID element_B; - LayoutTypeID layout_B; - NumericTypeID element_SFB; - NumericTypeID element_C; - LayoutTypeID layout_C; - NumericTypeID element_D; - LayoutTypeID layout_D; - NumericTypeID element_SFD; - LayoutTypeID layout_SFD; - int SFVecSize; - int EpilogueSFVecSize; - // - // Methods - // - - inline - BlockScaledGemmFunctionalKey( - Provider provider, - GemmKind gemm_kind = GemmKind::kGemm, - OperationKind kind = OperationKind::kBlockScaledGemm, - NumericTypeID element_compute = NumericTypeID::kF32, - NumericTypeID element_scalar = NumericTypeID::kF32, - NumericTypeID element_A = NumericTypeID::kF16, - LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, - NumericTypeID element_SFA = NumericTypeID::kF16, - NumericTypeID element_B = NumericTypeID::kF16, - LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, - NumericTypeID element_SFB = NumericTypeID::kF16, - NumericTypeID element_C = NumericTypeID::kF16, - LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, - NumericTypeID element_D = NumericTypeID::kF16, - LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, - NumericTypeID element_SFD = NumericTypeID::kF16, - LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor, - int sf_vec_size = 32 - , int epilogue_sf_vec_size = 32 - ): - provider(provider), - gemm_kind(gemm_kind), - kind(kind), - element_compute(element_compute), - element_scalar(element_scalar), - element_A(element_A), - layout_A(layout_A), - element_SFA(element_SFA), - element_B(element_B), - layout_B(layout_B), - element_SFB(element_SFB), - element_C(element_C), - layout_C(layout_C), - element_D(element_D), - layout_D(layout_D), - element_SFD(element_SFD), - layout_SFD(layout_SFD), - SFVecSize(sf_vec_size) - , EpilogueSFVecSize(epilogue_sf_vec_size) - { } - - inline - bool operator==(BlockScaledGemmFunctionalKey const &rhs) const { - return - (provider == rhs.provider) && - (gemm_kind == rhs.gemm_kind) && - (kind == rhs.kind) && - (element_compute == rhs.element_compute) && - (element_scalar == rhs.element_scalar) && - (element_A == rhs.element_A) && - (layout_A == rhs.layout_A) && - (element_SFA == rhs.element_SFA) && - (element_B == rhs.element_B) && - (layout_B == rhs.layout_B) && - (element_SFB == rhs.element_SFB) && - (element_C == rhs.element_C) && - (layout_C == rhs.layout_C) && - (element_D == rhs.element_D) && - (layout_D == rhs.layout_D) && - (element_SFD == rhs.element_SFD) && - (layout_SFD == rhs.layout_SFD) && - (SFVecSize == rhs.SFVecSize) - && (EpilogueSFVecSize == rhs.EpilogueSFVecSize) - ; - } - - inline - bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const { - return !(*this == rhs); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -inline -std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) { - - out << "{\n" - << " provider: " << to_string(k.provider) << "\n" - << " gemm_kind: " << to_string(k.gemm_kind) << "\n" - << " kind: " << to_string(k.kind) << "\n" - << " element_compute: " << to_string(k.element_compute) << "\n" - << " element_scalar: " << to_string(k.element_scalar) << "\n" - << " element_A: " << to_string(k.element_A) << "\n" - << " layout_A: " << to_string(k.layout_A) << "\n" - << " element_SFA: " << to_string(k.element_SFA) << "\n" - << " element_B: " << to_string(k.element_B) << "\n" - << " layout_B: " << to_string(k.layout_B) << "\n" - << " element_SFB: " << to_string(k.element_SFB) << "\n" - << " element_C: " << to_string(k.element_C) << "\n" - << " layout_C: " << to_string(k.layout_C) << "\n" - << " element_D: " << to_string(k.element_D) << "\n" - << " layout_D: " << to_string(k.layout_D) << "\n" - << " element_SFD: " << to_string(k.element_SFD) << "\n" - << " layout_SFD: " << to_string(k.layout_SFD) << "\n" - << " SFVecSize: " << k.SFVecSize << "\n" - << "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n" - << "}"; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Hash function for BlockScaledGemmFunctionalKeyHasher -struct BlockScaledGemmFunctionalKeyHasher { - using IntHash = std::hash; - - inline - static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); - } - - inline - size_t operator()(BlockScaledGemmFunctionalKey const &key) const { - IntHash hash; - - return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.gemm_kind)), 2) ^ - rotl(hash(int(key.kind)), 3) ^ - rotl(hash(int(key.element_compute)), 4) ^ - rotl(hash(int(key.element_scalar)), 5) ^ - rotl(hash(int(key.element_A)), 6) ^ - rotl(hash(int(key.layout_A)), 7) ^ - rotl(hash(int(key.element_SFA)), 8) ^ - rotl(hash(int(key.element_B)), 9) ^ - rotl(hash(int(key.layout_B)), 10) ^ - rotl(hash(int(key.element_SFB)), 11) ^ - rotl(hash(int(key.element_C)), 12) ^ - rotl(hash(int(key.layout_C)), 13) ^ - rotl(hash(int(key.element_D)), 14) ^ - rotl(hash(int(key.layout_D)), 15) ^ - rotl(hash(int(key.element_SFD)), 16) ^ - rotl(hash(int(key.layout_SFD)), 17) ^ - rotl(hash(int(key.SFVecSize)), 18) ^ - rotl(hash(int(key.EpilogueSFVecSize)), 19) - ; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm -using BlockScaledGemmOperationFunctionalMap = std::unordered_map< - BlockScaledGemmFunctionalKey, - GemmOperationVectorMap, - BlockScaledGemmFunctionalKeyHasher ->; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Data Structures for Blockwise Gemm Functional Maps -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tuple uniquely identifying Gemm functional behavior -struct BlockwiseGemmFunctionalKey { - - Provider provider; - GemmKind gemm_kind; - OperationKind kind; - NumericTypeID element_compute; - NumericTypeID element_scalar; - NumericTypeID element_A; - LayoutTypeID layout_A; - NumericTypeID element_SFA; - NumericTypeID element_B; - LayoutTypeID layout_B; - NumericTypeID element_SFB; - NumericTypeID element_C; - LayoutTypeID layout_C; - NumericTypeID element_D; - LayoutTypeID layout_D; - int SFMVecSize; - int SFNVecSize; - int SFKVecSize; - // - // Methods - // - - inline - BlockwiseGemmFunctionalKey( - Provider provider, - GemmKind gemm_kind = GemmKind::kGemm, - OperationKind kind = OperationKind::kBlockwiseGemm, - NumericTypeID element_compute = NumericTypeID::kF32, - NumericTypeID element_scalar = NumericTypeID::kF32, - NumericTypeID element_A = NumericTypeID::kF16, - LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, - NumericTypeID element_SFA = NumericTypeID::kF16, - NumericTypeID element_B = NumericTypeID::kF16, - LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, - NumericTypeID element_SFB = NumericTypeID::kF16, - NumericTypeID element_C = NumericTypeID::kF16, - LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, - NumericTypeID element_D = NumericTypeID::kF16, - LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, - int sfm_vec_size = 32, - int sfn_vec_size = 32, - int sfk_vec_size = 32 - ): - provider(provider), - gemm_kind(gemm_kind), - kind(kind), - element_compute(element_compute), - element_scalar(element_scalar), - element_A(element_A), - layout_A(layout_A), - element_SFA(element_SFA), - element_B(element_B), - layout_B(layout_B), - element_SFB(element_SFB), - element_C(element_C), - layout_C(layout_C), - element_D(element_D), - layout_D(layout_D), - SFMVecSize(sfm_vec_size), - SFNVecSize(sfn_vec_size), - SFKVecSize(sfk_vec_size) - { } - - inline - bool operator==(BlockwiseGemmFunctionalKey const &rhs) const { - return - (provider == rhs.provider) && - (gemm_kind == rhs.gemm_kind) && - (kind == rhs.kind) && - (element_compute == rhs.element_compute) && - (element_scalar == rhs.element_scalar) && - (element_A == rhs.element_A) && - (layout_A == rhs.layout_A) && - (element_SFA == rhs.element_SFA) && - (element_B == rhs.element_B) && - (layout_B == rhs.layout_B) && - (element_SFB == rhs.element_SFB) && - (element_C == rhs.element_C) && - (layout_C == rhs.layout_C) && - (element_D == rhs.element_D) && - (layout_D == rhs.layout_D) && - (SFMVecSize == rhs.SFMVecSize) && - (SFNVecSize == rhs.SFNVecSize) && - (SFKVecSize == rhs.SFKVecSize); - } - - inline - bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const { - return !(*this == rhs); - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -inline -std::ostream & operator<<(std::ostream &out, cutlass::library::BlockwiseGemmFunctionalKey const &k) { - - out << "{\n" - << " provider: " << to_string(k.provider) << "\n" - << " gemm_kind: " << to_string(k.gemm_kind) << "\n" - << " kind: " << to_string(k.kind) << "\n" - << " element_compute: " << to_string(k.element_compute) << "\n" - << " element_scalar: " << to_string(k.element_scalar) << "\n" - << " element_A: " << to_string(k.element_A) << "\n" - << " layout_A: " << to_string(k.layout_A) << "\n" - << " element_SFA: " << to_string(k.element_SFA) << "\n" - << " element_B: " << to_string(k.element_B) << "\n" - << " layout_B: " << to_string(k.layout_B) << "\n" - << " element_SFB: " << to_string(k.element_SFB) << "\n" - << " element_C: " << to_string(k.element_C) << "\n" - << " layout_C: " << to_string(k.layout_C) << "\n" - << " element_D: " << to_string(k.element_D) << "\n" - << " layout_D: " << to_string(k.layout_D) << "\n" - << " SFMVecSize: " << k.SFMVecSize << "\n" - << " SFNVecSize: " << k.SFNVecSize << "\n" - << " SFKVecSize: " << k.SFKVecSize << "\n" - << "}"; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Hash function for BlockwiseGemmFunctionalKeyHasher -struct BlockwiseGemmFunctionalKeyHasher { - using IntHash = std::hash; - - inline - static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); - } - - inline - size_t operator()(BlockwiseGemmFunctionalKey const &key) const { - IntHash hash; - - return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.gemm_kind)), 2) ^ - rotl(hash(int(key.kind)), 3) ^ - rotl(hash(int(key.element_compute)), 4) ^ - rotl(hash(int(key.element_scalar)), 5) ^ - rotl(hash(int(key.element_A)), 6) ^ - rotl(hash(int(key.layout_A)), 7) ^ - rotl(hash(int(key.element_SFA)), 8) ^ - rotl(hash(int(key.element_B)), 9) ^ - rotl(hash(int(key.layout_B)), 10) ^ - rotl(hash(int(key.element_SFB)), 11) ^ - rotl(hash(int(key.element_C)), 12) ^ - rotl(hash(int(key.layout_C)), 13) ^ - rotl(hash(int(key.element_D)), 14) ^ - rotl(hash(int(key.layout_D)), 15) ^ - rotl(hash(int(key.SFMVecSize)), 16) ^ - rotl(hash(int(key.SFNVecSize)), 17) ^ - rotl(hash(int(key.SFKVecSize)), 18) - ; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm -using BlockwiseGemmOperationFunctionalMap = std::unordered_map< - BlockwiseGemmFunctionalKey, - GemmOperationVectorMap, - BlockwiseGemmFunctionalKeyHasher ->; - - - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Data Structures for Conv Functional Maps -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tuple uniquely identifying conv2d functional behavior -struct ConvFunctionalKey { - library::Provider provider; - library::ConvKind conv_kind; - library::NumericTypeID element_A; - library::LayoutTypeID layout_A; - library::NumericTypeID element_B; - library::LayoutTypeID layout_B; - library::NumericTypeID element_C; - library::LayoutTypeID layout_C; - library::NumericTypeID element_accumulator; - library::NumericTypeID element_compute; - - - // - // Methods - // - - inline - ConvFunctionalKey( - library::Provider provider = library::Provider::kInvalid, - library::ConvKind conv_kind = library::ConvKind::kFprop, - library::NumericTypeID element_A = library::NumericTypeID::kF16, - library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, - library::NumericTypeID element_B = library::NumericTypeID::kF16, - library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, - library::NumericTypeID element_C = library::NumericTypeID::kF16, - library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, - library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, - library::NumericTypeID element_compute = library::NumericTypeID::kF32 - ): - provider(provider), - conv_kind(conv_kind), - element_A(element_A), - layout_A(layout_A), - element_B(element_B), - layout_B(layout_B), - element_C(element_C), - layout_C(layout_C), - element_accumulator(element_accumulator), - element_compute(element_compute) - { } - - inline - bool operator==(ConvFunctionalKey const &rhs) const { - return - (provider == rhs.provider) && - (conv_kind == rhs.conv_kind) && - (element_A == rhs.element_A) && - (layout_A == rhs.layout_A) && - (element_B == rhs.element_B) && - (layout_B == rhs.layout_B) && - (element_C == rhs.element_C) && - (layout_C == rhs.layout_C) && - (element_accumulator == rhs.element_accumulator) && - (element_compute == rhs.element_compute); - } - - inline - bool operator!=(ConvFunctionalKey const &rhs) const { - return !(*this == rhs); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// -inline -std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { - out << "{\n" - << "provider: " << to_string(key.provider) << std::endl - << "conv_kind: " << to_string(key.conv_kind) << std::endl - << "element_A: " << to_string(key.element_A) << std::endl - << "layout_A: " << to_string(key.layout_A) << std::endl - << "element_B: " << to_string(key.element_B) << std::endl - << "layout_B: " << to_string(key.layout_B) << std::endl - << "element_C: " << to_string(key.element_C) << std::endl - << "layout_C: " << to_string(key.layout_C) << std::endl - << "element_accumulator: " << to_string(key.element_accumulator) << std::endl - << "element_compute: " << to_string(key.element_compute) << std::endl - << "}"; - - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -struct ConvFunctionalKeyHasher { - using IntHash = std::hash; - - inline - static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); - } - - inline - size_t operator()(ConvFunctionalKey const &key) const { - IntHash hash; - - return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.conv_kind)), 2) ^ - rotl(hash(int(key.element_A)), 3) ^ - rotl(hash(int(key.layout_A)), 4) ^ - rotl(hash(int(key.element_B)), 5) ^ - rotl(hash(int(key.layout_B)), 6) ^ - rotl(hash(int(key.element_C)), 7) ^ - rotl(hash(int(key.layout_C)), 8) ^ - rotl(hash(int(key.element_accumulator)), 9) ^ - rotl(hash(int(key.element_compute)), 10); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Establishes a partial ordering to search for Conv2d operators -struct ConvPreferenceKey { - - int compute_capability; - IteratorAlgorithmID iterator_algorithm; - - - // - // Methods - // - - ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } - - ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): - compute_capability(cc), iterator_algorithm(iterator_algorithm) { } - - bool operator<(ConvPreferenceKey const &rhs) const { - return (compute_capability < rhs.compute_capability) || - ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); - } - - bool operator==(ConvPreferenceKey const &rhs) const { - return (compute_capability == rhs.compute_capability) && - (iterator_algorithm == rhs.iterator_algorithm); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Maps minimum compute capability onto a vector of possible operations -using ConvOperationVectorMap = std::map< - ConvPreferenceKey, - std::vector ->; - -/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm -using ConvOperationFunctionalMap = std::unordered_map< - ConvFunctionalKey, - ConvOperationVectorMap, - ConvFunctionalKeyHasher ->; -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Tuple uniquely identifying conv2d functional behavior -struct ReductionFunctionalKey { - library::Provider provider; - library::NumericTypeID element_workspace; - library::NumericTypeID element_accumulator; - library::NumericTypeID element_output; - library::NumericTypeID element_compute; - library::MathOperationID reduce_math_op; - library::EpilogueKind epilogue_math_op; - - - // - // Methods - // - - inline - ReductionFunctionalKey( - library::Provider provider = library::Provider::kInvalid, - library::NumericTypeID element_workspace = library::NumericTypeID::kF16, - library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, - library::NumericTypeID element_output = library::NumericTypeID::kF16, - library::NumericTypeID element_compute = library::NumericTypeID::kF32, - library::MathOperationID reduce_math_op = library::MathOperationID::kAdd, - library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination - ): - provider(provider), - element_workspace(element_workspace), - element_accumulator(element_accumulator), - element_output(element_output), - element_compute(element_compute), - reduce_math_op(reduce_math_op), - epilogue_math_op(epilogue_math_op) - { } - - inline - bool operator==(ReductionFunctionalKey const &rhs) const { - return - (provider == rhs.provider) && - (element_workspace == rhs.element_workspace) && - (element_accumulator == rhs.element_accumulator) && - (element_output == rhs.element_output) && - (element_compute == rhs.element_compute) && - (reduce_math_op == rhs.reduce_math_op) && - (epilogue_math_op == rhs.epilogue_math_op); - } - - inline - bool operator!=(ReductionFunctionalKey const &rhs) const { - return !(*this == rhs); - } -}; - - -struct ReductionFunctionalKeyHasher { - using IntHash = std::hash; - - inline - static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); - } - - inline - size_t operator()(ReductionFunctionalKey const &key) const { - IntHash hash; - - return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.element_workspace)), 2) ^ - rotl(hash(int(key.element_accumulator)), 3) ^ - rotl(hash(int(key.element_output)), 4) ^ - rotl(hash(int(key.element_compute)), 5) ^ - rotl(hash(int(key.reduce_math_op)), 6) ^ - rotl(hash(int(key.epilogue_math_op)), 7); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -inline -std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { - out << "{\n" - << "provider: " << library::to_string(key.provider) << std::endl - << "element_workspace : " << library::to_string(key.element_workspace) << std::endl - << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl - << "element_output : " << library::to_string(key.element_output) << std::endl - << "element_compute : " << library::to_string(key.element_compute) << std::endl - << "}"; - return out; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key -// i.e. only one tile size configuration per functional key -using ReductionOperationFunctionalMap = std::unordered_map< - ReductionFunctionalKey, - library::Operation const *, - ReductionFunctionalKeyHasher ->; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Table of cutlass::library::Operation instances -class OperationTable { -public: - - /// Map of all operations of type kGemm - // provider (kCUTLASS) - GemmOperationFunctionalMap gemm_operations; - - // provider (kCUTLASS, kReferenceHost, kReferenceDevice) - BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations; - - // provider (kCUTLASS, kReferenceHost, kReferenceDevice) - BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations; - - /// Map of all operations of type kConv2d - // provider (kCUTLASS, kReferenceHost, kReferenceDevice) - ConvOperationFunctionalMap conv2d_operations; - - /// Map of all operations of type kConv3d - // provider (kCUTLASS, kReferenceHost, kReferenceDevice) - ConvOperationFunctionalMap conv3d_operations; - - /// Map of all operations of type kConv2d - // provider (kCUTLASS) - ReductionOperationFunctionalMap reduction_operations; - -public: - - void append(Manifest const &manifest); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h deleted file mode 100644 index 9a757433f38fbf10d9a352e07c7f3084a99e4098..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h +++ /dev/null @@ -1,68 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/operation_table.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Singleton instance stores a Manifest and Operation table -class Singleton { -public: - - /// Manifest object - Manifest manifest; - - /// Operation table referencing the Manifest - OperationTable operation_table; - -public: - - Singleton(); - - static Singleton const &get(); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h deleted file mode 100644 index 9f8c4ff13ba543b4ec63997ba55e9278bfb357a6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h +++ /dev/null @@ -1,295 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - #pragma once - - ///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Layout type identifier -enum class LayoutTypeID { - kUnknown, - kColumnMajor, - kRowMajor, - kBlockScalingTensor, - kColumnMajorInterleavedK2, - kRowMajorInterleavedK2, - kColumnMajorInterleavedK4, - kRowMajorInterleavedK4, - kColumnMajorInterleavedK16, - kRowMajorInterleavedK16, - kColumnMajorInterleavedK32, - kRowMajorInterleavedK32, - kColumnMajorInterleavedK64, - kRowMajorInterleavedK64, - kTensorNCHW, - kTensorNCDHW, - kTensorNHWC, - kTensorNDHWC, - kTensorNC32HW32, - kTensorC32RSK32, - kTensorNC64HW64, - kTensorC64RSK64, - kInvalid -}; - -/// Numeric data type -enum class NumericTypeID { - kUnknown, - kVoid, - kB1, - kU2, - kU4, - kU8, - kU16, - kU32, - kU64, - kS2, - kS4, - kS8, - kS16, - kS32, - kS64, - kFE4M3, - kFE5M2, - - kFE2M3, - kFE3M2, - kFE2M1, - kFUE8M0, - kFUE4M3, - kF8, - kF6, - kF4, - - kF16, - kBF16, - kTF32, - kF32, - kF64, - kCF16, - kCBF16, - kCF32, - kCTF32, - kCF64, - kCS2, - kCS4, - kCS8, - kCS16, - kCS32, - kCS64, - kCU2, - kCU4, - kCU8, - kCU16, - kCU32, - kCU64, - kInvalid -}; - -/// Enumerated type describing a transformation on a complex value. -enum class ComplexTransform { - kNone, - kConjugate, - kInvalid -}; - -/// Providers -enum class Provider { - kNone, - kCUTLASS, - kReferenceHost, - kReferenceDevice, - kCUBLAS, - kCUDNN, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Enumeration indicating the kind of operation -enum class OperationKind { - kGemm, - kBlockScaledGemm, - kBlockwiseGemm, - kRankK, - kRank2K, - kTrmm, - kSymm, - kConv2d, - kConv3d, - kEqGemm, - kSparseGemm, - kReduction, - kGroupedGemm, - kInvalid -}; - -/// Enumeration indicating whether scalars are in host or device memory -enum class ScalarPointerMode { - kHost, - kDevice, - kInvalid -}; - -/// Describes how reductions are performed across threadblocks -enum class SplitKMode { - kNone, - kSerial, - kParallel, - kParallelSerial, - kInvalid -}; - -/// Indicates the classificaition of the math instruction -enum class OpcodeClassID { - kSimt, - kTensorOp, - kWmmaTensorOp, - kSparseTensorOp, - kBlockScaledOp, - kInvalid -}; - -enum class MathOperationID { - kAdd, - kMultiplyAdd, - kMultiplyAddSaturate, - kMultiplyAddMixedInputUpcast, - kMultiplyAddFastBF16, - kMultiplyAddFastF16, - kMultiplyAddFastF32, - kMultiplyAddComplex, - kMultiplyAddComplexFastF32, - kMultiplyAddGaussianComplex, - kXorPopc, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Enumeration indicating what kind of GEMM operation to perform -enum class GemmKind { - kGemm, - kBlockScaledGemm, - kSparse, - kUniversal, - kPlanarComplex, - kPlanarComplexArray, - kGrouped, - kInvalid -}; - -/// Enumeration indicating what kind of RankK update operation to perform -enum class RankKKind { - kUniversal, - kInvalid -}; - -/// Enumeration indicating what kind of TRMM operation to perform -enum class TrmmKind { - kUniversal, - kInvalid -}; - -/// Enumeration indicating what kind of SYMM/HEMM operation to perform -enum class SymmKind { - kUniversal, - kInvalid -}; - -/// Enumeration indicating what kind of Conv2d operation to perform -enum class ConvKind { - kUnknown, - kFprop, - kDgrad, - kWgrad, - kInvalid -}; - -enum class ConvModeID { - kCrossCorrelation, - kConvolution, - kInvalid -}; - -// Iterator algorithm enum in order of general performance-efficiency -enum class IteratorAlgorithmID { - kNone, - kAnalytic, - kOptimized, - kFixedChannels, - kFewChannels, - kInvalid -}; - - -enum class EpilogueKind { - kUnknown, - kConversion, - kLinearCombination, - kLinearCombinationClamp, - kLinearCombinationPlanarComplex, - kLinearCombinationRelu, - kLinearCombinationSigmoid, - kInvalid -}; - - -enum class RuntimeDatatype { - kStatic, - kE4M3, - kE5M2, - kE3M2, - kE2M3, - kE2M1, - - kInvalid -}; - - -enum class RasterOrder { - kAlongN, - kAlongM, - kHeuristic, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h deleted file mode 100644 index f537421751c1f2af3b95a2e1951006af441b28e0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h +++ /dev/null @@ -1,281 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - - \brief Utilities accompanying the CUTLASS library for interacting with Library types. -*/ - -#ifndef CUTLASS_LIBRARY_UTIL_H -#define CUTLASS_LIBRARY_UTIL_H - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Lexical cast from string -template T from_string(std::string const &); - -/// Converts a Provider enumerant to a string -char const *to_string(Provider provider, bool pretty = false); - -/// Parses a Provider enumerant from a string -template <> Provider from_string(std::string const &str); - -/// Converts a GemmKind enumerant to a string -char const *to_string(GemmKind type, bool pretty = false); - -/// Converts a RankKKind enumerant to a string -char const *to_string(RankKKind type, bool pretty = false); - -/// Converts a TrmmKind enumerant to a string -char const *to_string(TrmmKind type, bool pretty = false); - -/// Converts a SymmKind enumerant to a string -char const *to_string(SymmKind type, bool pretty = false); - -/// Converts a SideMode enumerant to a string -char const *to_string(SideMode type, bool pretty = false); - -/// Converts a FillMode enumerant to a string -char const *to_string(FillMode type, bool pretty = false); - -/// Converts a BlasMode enumerant to a string -char const *to_string(BlasMode type, bool pretty = false); - -/// Converts a DiagType enumerant to a string -char const *to_string(DiagType type, bool pretty = false); - -/// Converts a NumericType enumerant to a string -char const *to_string(OperationKind type, bool pretty = false); - -/// Parses a NumericType enumerant from a string -template <> OperationKind from_string(std::string const &str); - -/// Converts a NumericType enumerant to a string -char const *to_string(NumericTypeID type, bool pretty = false); - -/// Parses a NumericType enumerant from a string -template <> NumericTypeID from_string(std::string const &str); - -/// Returns the size of a data type in bits -int sizeof_bits(NumericTypeID type); - -/// Returns true if the numeric type is a complex data type or false if real-valued. -bool is_complex_type(NumericTypeID type); - -/// Returns the real-valued type underlying a type (only different from 'type' if complex) -NumericTypeID get_real_type(NumericTypeID type); - -/// Returns true if numeric type is integer -bool is_integer_type(NumericTypeID type); - -/// Returns true if numeric type is signed -bool is_signed_type(NumericTypeID type); - -/// Returns true if numeric type is a signed integer -bool is_signed_integer(NumericTypeID type); - -/// returns true if numeric type is an unsigned integer -bool is_unsigned_integer(NumericTypeID type); - -/// Returns true if numeric type is floating-point type -bool is_float_type(NumericTypeID type); - -/// To string method for cutlass::Status -char const *to_string(Status status, bool pretty = false); - -/// Converts a LayoutTypeID enumerant to a string -char const *to_string(LayoutTypeID layout, bool pretty = false); - -/// Parses a LayoutType enumerant from a string -template <> LayoutTypeID from_string(std::string const &str); - -/// Returns the rank of a layout's stride base on the LayoutTypeID -int get_layout_stride_rank(LayoutTypeID layout_id); - -/// Converts a OpcodeClassID enumerant to a string -char const *to_string(OpcodeClassID type, bool pretty = false); - -/// Converts a OpcodeClassID enumerant from a string -template <> -OpcodeClassID from_string(std::string const &str); - -/// Converts a ComplexTransform enumerant to a string -char const *to_string(ComplexTransform type, bool pretty = false); - -/// Converts a ComplexTransform enumerant from a string -template <> -ComplexTransform from_string(std::string const &str); - - -/// Converts a SplitKMode enumerant to a string -char const *to_string(SplitKMode split_k_mode, bool pretty = false); - -/// Converts a SplitKMode enumerant from a string -template <> -SplitKMode from_string(std::string const &str); - -/// Converts a ConvModeID enumerant to a string -char const *to_string(ConvModeID type, bool pretty = false); - -/// Converts a ConvModeID enumerant from a string -template <> -ConvModeID from_string(std::string const &str); - -/// Converts a IteratorAlgorithmID enumerant to a string -char const *to_string(IteratorAlgorithmID type, bool pretty = false); - -/// Converts a IteratorAlgorithmID enumerant from a string -template <> -IteratorAlgorithmID from_string(std::string const &str); - -/// Converts a ConvKind enumerant to a string -char const *to_string(ConvKind type, bool pretty = false); - -/// Converts a ConvKind enumerant from a string -template <> -ConvKind from_string(std::string const &str); - - -/// Converts a RuntimeDatatype enumerant to a string -char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false); - -/// Convers a RuntimeDatatype enumerant from a string -template<> -cutlass::library::RuntimeDatatype from_string(std::string const &str); - - -/// Converts a RasterOrder enumerant to a string -char const *to_string(RasterOrder type, bool pretty = false); - -/// Convers a RasterOrder enumerant from a string -template<> -RasterOrder from_string(std::string const &str); - -/// Converts a bool to a string -char const *to_string(bool type, bool pretty = false); - -/// Convers a bool from a string -template<> -bool from_string(std::string const &str); - -/// Lexical cast from int64_t to string -std::string lexical_cast(int64_t int_value); - -/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); - -/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. -std::string lexical_cast(std::vector &bytes, NumericTypeID type); - -/// Casts from a signed int64 to the destination type. Returns true if successful. -bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); - -/// Casts from an unsigned int64 to the destination type. Returns true if successful. -bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); - -/// Casts from a real value represented as a double to the destination type. Returns true if successful. -bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); - -NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); - -#define CUDA_CHECK(call) \ - do { \ - cudaError_t err = (call); \ - if (err != cudaSuccess) { \ - std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \ - << __FILE__ << ":" << __LINE__ << std::endl; \ - return Status::kInvalid; \ - } \ - } while (0) - -// RAII CUDA buffer container -class CudaBuffer { -public: - CudaBuffer() : size_(0), d_ptr_(nullptr) {} - - explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { - cudaError_t err = cudaMalloc(&d_ptr_, size_); - if (err != cudaSuccess) { - throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); - } - } - - ~CudaBuffer() { - if (d_ptr_) { - cudaFree(d_ptr_); - } - } - - CudaBuffer(CudaBuffer const&) = delete; - CudaBuffer& operator=(CudaBuffer const&) = delete; - - CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) { - other.d_ptr_ = nullptr; - other.size_ = 0; - } - - CudaBuffer& operator=(CudaBuffer&& other) noexcept { - if (this != &other) { - if (d_ptr_) { - cudaFree(d_ptr_); - } - d_ptr_ = other.d_ptr_; - size_ = other.size_; - other.d_ptr_ = nullptr; - other.size_ = 0; - } - return *this; - } - - void* data() const noexcept { return d_ptr_; } - size_t size() const noexcept { return size_; } - -private: - size_t size_; - void* d_ptr_; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#endif diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp deleted file mode 100644 index c96b9a2212b42c191551ea70da3ac3baecbed487..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp +++ /dev/null @@ -1,450 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/collective.hpp" -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "gemm_operation_3x.hpp" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementA = typename Operator::CollectiveMainloop::ElementA; - using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::CollectiveMainloop::ElementB; - using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using TiledMma = typename Operator::CollectiveMainloop::TiledMma; - constexpr static int SFVecSize = TiledMma::SFVecSize; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - using Sm1xxBlkScaledConfig = typename CollectiveMainloop::Sm1xxBlkScaledConfig; - - static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; - static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; - using ElementSFD = cute::conditional_t; - using LayoutSFD = cute::conditional_t; - - - - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), - "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); - - static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; - using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; - - -private: - BlockScaledGemmDescription description_; - -public: - - /// Constructor - BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) { - description_.kind = OperationKind::kBlockScaledGemm; - description_.SFA.element = NumericTypeMap::kId; - description_.SFA.layout = LayoutTypeID::kRowMajor; - description_.SFA.alignment = 128; - description_.SFA.log_extent_range = 32; - description_.SFA.log_stride_range = 32; - - description_.SFB.element = NumericTypeMap::kId; - description_.SFB.layout = LayoutTypeID::kRowMajor; - description_.SFB.alignment = 128; - description_.SFB.log_extent_range = 32; - description_.SFB.log_stride_range = 32; - - description_.SFVecSize = SFVecSize; - - description_.SFD = make_TensorDescription(128); - description_.EpilogueSFVecSize = SFD_VectorSize; - - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.gemm_kind = GemmKind::kUniversal; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { - description_.tile_description.cluster_shape = make_Coord( - Operator::ClusterShape::kM, - Operator::ClusterShape::kN, - Operator::ClusterShape::kK); - } - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::WarpCount::kM, - Operator::WarpCount::kN, - Operator::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.D = make_TensorDescription(Operator::kAlignmentD); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - /// Returns the description of the GEMM operation - BlockScaledGemmDescription const& get_gemm_description() const { - return description_; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { - // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides - // Do nothing here and construct kernel arguments in update_arguments_ instead - // We also cannot construct TMA descriptors without all the arguments available - - operator_args.mode = configuration->mode; - return Status::kSuccess; - } - - template - struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { - - if constexpr (epilogue_scalefactor_generation) { - fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); - fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); - } - - - if (arguments.pointer_mode == ScalarPointerMode::kHost) { - fusion_args.alpha = *static_cast(arguments.alpha); - fusion_args.beta = *static_cast(arguments.beta); - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - - return Status::kSuccess; - } - else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = static_cast(arguments.alpha); - fusion_args.beta_ptr = static_cast(arguments.beta); - - return Status::kSuccess; - } - else { - return Status::kErrorInvalidProblem; - } - } - }; - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - BlockScaledGemmArguments const *arguments) { - Status status = Status::kSuccess; - - status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, *arguments); - if (status != Status::kSuccess) { - return status; - } - - operator_args.problem_shape = cute::make_shape( - arguments->problem_size.m(), - arguments->problem_size.n(), - arguments->problem_size.k(), - arguments->batch_count); - - // update arguments - - if constexpr (IsRuntimeDataType) { - using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; - operator_args.mainloop.ptr_A = static_cast(arguments->A); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - - using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; - using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; - - static_assert(cute::is_same_v, - "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); - using RuntimeDatatypeArg = RuntimeDataTypeA; - - auto mapping = [](RuntimeDatatype type) { - if constexpr (cute::is_same_v) { - if (type == RuntimeDatatype::kE3M2) { - return cute::UMMA::MXF8F6F4Format::E3M2; - } else if (type == RuntimeDatatype::kE2M3) { - return cute::UMMA::MXF8F6F4Format::E2M3; - } else if (type == RuntimeDatatype::kE2M1) { - return cute::UMMA::MXF8F6F4Format::E2M1; - } else { - assert("Invalid input datatype."); - } - } - else if constexpr (cute::is_same_v) { - if (type == RuntimeDatatype::kE2M1) { - return cute::UMMA::MXF4Format::E2M1; - } else { - assert("Invalid input datatype."); - } - } - // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype - CUTE_GCC_UNREACHABLE; - }; - - operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a); - operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b); - - } - else { - - operator_args.mainloop.ptr_A = static_cast(arguments->A); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - } - operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); - operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); - operator_args.epilogue.ptr_C = static_cast(arguments->C); - operator_args.epilogue.ptr_D = static_cast(arguments->D); - - operator_args.mainloop.dA = cute::make_int_tuple_from( - arguments->lda, arguments->batch_stride_A); - operator_args.mainloop.dB = cute::make_int_tuple_from( - arguments->ldb, arguments->batch_stride_B); - operator_args.epilogue.dC = cute::make_int_tuple_from( - arguments->ldc, arguments->batch_stride_C); - operator_args.epilogue.dD = operator_args.epilogue.dC; - - operator_args.mainloop.layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); - operator_args.mainloop.layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); - - /* Query device SM count to pass onto the kernel as an argument, where needed */ - operator_args.hw_info.sm_count = arguments->sm_count; - if constexpr (!std::is_const_v) { - operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; - } - - if constexpr (!std::is_const_v) { - using Enum_t = decltype(operator_args.scheduler.raster_order); - switch (arguments->raster_order) { - case RasterOrder::kAlongN: - operator_args.scheduler.raster_order = Enum_t::AlongN; - break; - case RasterOrder::kAlongM: - operator_args.scheduler.raster_order = Enum_t::AlongM; - break; - default: - operator_args.scheduler.raster_order = Enum_t::Heuristic; - } - } - - if constexpr (std::is_same_v) { - operator_args.scheduler.splits = arguments->split_k_slices; - } - - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { - operator_args.hw_info.cluster_shape = dim3( - arguments->cluster_shape.m(), - arguments->cluster_shape.n(), - arguments->cluster_shape.k()); - operator_args.hw_info.cluster_shape_fallback = dim3( - arguments->cluster_shape_fallback.m(), - arguments->cluster_shape_fallback.n(), - arguments->cluster_shape_fallback.k()); - } - - return status; - } - -public: - - /// Returns success if the operation can proceed - Status can_implement( - void const *configuration_ptr, void const *arguments_ptr) const override { - - GemmUniversalConfiguration const *configuration = - static_cast(configuration_ptr); - BlockScaledGemmArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - auto status = update_arguments_(args, arguments); - if (status != Status::kSuccess) { - return status; - } - - // can_implement rules may need access to problem shape - args.problem_shape = cute::make_shape( - configuration->problem_size.m(), - configuration->problem_size.n(), - configuration->problem_size.k(), - configuration->batch_count); - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - uint64_t get_host_workspace_size(void const *configuration) const override { - return sizeof(Operator); - } - - /// Gets the device-side workspace - uint64_t get_device_workspace_size( - void const *configuration_ptr,void const *arguments_ptr) const override { - - OperatorArguments args; - auto status = update_arguments_( - args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - return size; - } - - /// Initializes the workspace - Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const override { - Operator *op = new (host_workspace) Operator; - return Status::kSuccess; - } - - Status initialize_with_profiler_workspace( - void const *configuration, - void *host_workspace, - void *device_workspace, - uint8_t **profiler_workspaces, - int problem_count_from_profiler, - cudaStream_t stream = nullptr) { - return Status::kSuccess; - } - - /// Runs the kernel - Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { - - OperatorArguments args; - Status status = update_arguments_(args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); - return status; - } -}; -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::library - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp deleted file mode 100644 index 00347a993e29035e58401e69698267045b399f7d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp +++ /dev/null @@ -1,429 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/collective.hpp" -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "gemm_operation_3x.hpp" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementA = typename Operator::CollectiveMainloop::ElementA; - using ElementSFA = typename Operator::ElementAccumulator; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::CollectiveMainloop::ElementB; - using ElementSFB = typename Operator::ElementAccumulator; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using TiledMma = typename Operator::CollectiveMainloop::TiledMma; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), - "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); - - static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - -private: - BlockwiseGemmDescription description_; - -public: - - /// Constructor - BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) { - description_.kind = OperationKind::kBlockwiseGemm; - description_.SFA.element = NumericTypeMap::kId; - description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ? - LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; - description_.SFA.alignment = CollectiveMainloop::AlignmentSFA; - description_.SFA.log_extent_range = 32; - description_.SFA.log_stride_range = 32; - - description_.SFB.element = NumericTypeMap::kId; - description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ? - LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; - description_.SFB.alignment = CollectiveMainloop::AlignmentSFA; - description_.SFB.log_extent_range = 32; - description_.SFB.log_stride_range = 32; - - description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; - description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; - description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.gemm_kind = GemmKind::kUniversal; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { - description_.tile_description.cluster_shape = make_Coord( - Operator::ClusterShape::kM, - Operator::ClusterShape::kN, - Operator::ClusterShape::kK); - } - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::WarpCount::kM, - Operator::WarpCount::kN, - Operator::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.D = make_TensorDescription(Operator::kAlignmentD); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - /// Returns the description of the GEMM operation - BlockwiseGemmDescription const& get_gemm_description() const { - return description_; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { - // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides - // Do nothing here and construct kernel arguments in update_arguments_ instead - // We also cannot construct TMA descriptors without all the arguments available - - operator_args.mode = configuration->mode; - return Status::kSuccess; - } - - template - struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) { - if (arguments.pointer_mode == ScalarPointerMode::kHost) { - fusion_args.alpha = *static_cast(arguments.alpha); - fusion_args.beta = *static_cast(arguments.beta); - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - - return Status::kSuccess; - } - else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = static_cast(arguments.alpha); - fusion_args.beta_ptr = static_cast(arguments.beta); - - return Status::kSuccess; - } - else { - return Status::kErrorInvalidProblem; - } - } - }; - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - BlockwiseGemmArguments const *arguments) { - Status status = Status::kSuccess; - - status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, *arguments); - if (status != Status::kSuccess) { - return status; - } - - operator_args.problem_shape = cute::make_shape( - arguments->problem_size.m(), - arguments->problem_size.n(), - arguments->problem_size.k(), - arguments->batch_count); - - // update arguments - - if constexpr (IsRuntimeDataType) { - using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; - operator_args.mainloop.ptr_A = static_cast(arguments->A); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - - std::unordered_map mapping = { - {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, - {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, - {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, - {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} - }; - - auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); - auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); - - if (iter_runtime_a != mapping.end()) { - operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; - } else { - assert("invalid runtime argument for datatype A!"); - } - - if (iter_runtime_b != mapping.end()) { - operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; - } else { - assert("invalid runtime argument for datatype B!"); - } - - } - else { - - operator_args.mainloop.ptr_A = static_cast(arguments->A); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - } - operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); - operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); - operator_args.epilogue.ptr_C = static_cast(arguments->C); - operator_args.epilogue.ptr_D = static_cast(arguments->D); - - operator_args.mainloop.dA = cute::make_int_tuple_from( - arguments->lda, arguments->batch_stride_A); - operator_args.mainloop.dB = cute::make_int_tuple_from( - arguments->ldb, arguments->batch_stride_B); - operator_args.epilogue.dC = cute::make_int_tuple_from( - arguments->ldc, arguments->batch_stride_C); - operator_args.epilogue.dD = operator_args.epilogue.dC; - - operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); - operator_args.mainloop.layout_SFB = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); - - /* Query device SM count to pass onto the kernel as an argument, where needed */ - operator_args.hw_info.sm_count = arguments->sm_count; - if constexpr (!std::is_const_v) { - operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; - } - - if constexpr (!std::is_const_v) { - using Enum_t = decltype(operator_args.scheduler.raster_order); - switch (arguments->raster_order) { - case RasterOrder::kAlongN: - operator_args.scheduler.raster_order = Enum_t::AlongN; - break; - case RasterOrder::kAlongM: - operator_args.scheduler.raster_order = Enum_t::AlongM; - break; - default: - operator_args.scheduler.raster_order = Enum_t::Heuristic; - } - } - - if constexpr (std::is_same_v) { - operator_args.scheduler.splits = arguments->split_k_slices; - } - - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { - operator_args.hw_info.cluster_shape = dim3( - arguments->cluster_shape.m(), - arguments->cluster_shape.n(), - arguments->cluster_shape.k()); - operator_args.hw_info.cluster_shape_fallback = dim3( - arguments->cluster_shape_fallback.m(), - arguments->cluster_shape_fallback.n(), - arguments->cluster_shape_fallback.k()); - } - - return status; - } - -public: - - /// Returns success if the operation can proceed - Status can_implement( - void const *configuration_ptr, void const *arguments_ptr) const override { - - GemmUniversalConfiguration const *configuration = - static_cast(configuration_ptr); - BlockwiseGemmArguments const *arguments = - static_cast(arguments_ptr); - - if (arguments->sf_m_vec_size != description_.SFMVecSize && arguments->sf_m_vec_size != 0) { - return Status::kErrorInvalidProblem; - } - if (arguments->sf_n_vec_size != description_.SFNVecSize && arguments->sf_n_vec_size != 0) { - return Status::kErrorInvalidProblem; - } - if (arguments->sf_k_vec_size != description_.SFKVecSize && arguments->sf_k_vec_size != 0) { - return Status::kErrorInvalidProblem; - } - - OperatorArguments args; - auto status = update_arguments_(args, arguments); - if (status != Status::kSuccess) { - return status; - } - - // can_implement rules may need access to problem shape - args.problem_shape = cute::make_shape( - configuration->problem_size.m(), - configuration->problem_size.n(), - configuration->problem_size.k(), - configuration->batch_count); - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - uint64_t get_host_workspace_size(void const *configuration) const override { - return sizeof(Operator); - } - - /// Gets the device-side workspace - uint64_t get_device_workspace_size( - void const *configuration_ptr,void const *arguments_ptr) const override { - - OperatorArguments args; - auto status = update_arguments_( - args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - return size; - } - - /// Initializes the workspace - Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const override { - Operator *op = new (host_workspace) Operator; - return Status::kSuccess; - } - - Status initialize_with_profiler_workspace( - void const *configuration, - void *host_workspace, - void *device_workspace, - uint8_t **profiler_workspaces, - int problem_count_from_profiler, - cudaStream_t stream = nullptr) { - return Status::kSuccess; - } - - /// Runs the kernel - Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { - - OperatorArguments args; - Status status = update_arguments_(args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); - return status; - } -}; -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::library - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h deleted file mode 100644 index 3b1a1584db92c4379e04c84a2658f79313b3eaad..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h +++ /dev/null @@ -1,650 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all CONV operation kinds in CUTLASS Library. -*/ - -#pragma once -#include -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -#include "cutlass/conv/kernel/default_depthwise_fprop.h" -#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/conv/device/direct_convolution.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/util/host_tensor.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/core_io.h" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class Conv2dOperationBase : public Operation { -public: - - using Operator = Operator_; - - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; - static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - ConvDescription description_; - -public: - - /// Constructor - Conv2dOperationBase(char const *name = "unknown_conv2d") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.kind = OperationKind::kConv2d; - description_.conv_dim = Operator::kConvDim; - - description_.iterator_algorithm = IteratorAlgorithmMap::kId; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::UnderlyingKernel::WarpCount::kM, - Operator::UnderlyingKernel::WarpCount::kN, - Operator::UnderlyingKernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(); - description_.B = make_TensorDescription(); - description_.C = make_TensorDescription(); - description_.element_epilogue = NumericTypeMap::kId; - - // TODO: Add split k mode Serial and parallel to convolutions - // description_.split_k_mode = Operator::kSplitK ? SplitKMode::kSerial : SplitKMode::kNone; - - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Conv2d library operation class for cutlass profiler -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -template -class Conv2dOperation : public Conv2dOperationBase { -public: - - using Operator = Operator_; - - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; - - using OperatorArguments = typename Operator::Arguments; - -public: - /// Constructor - Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { - this->description_.conv_kind = ConvKindMap::kId; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - Conv2dConfiguration const *configuration) { - - - operator_args.problem_size = configuration->problem_size; - - operator_args.ref_A = - { - nullptr, - LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_B = - { - nullptr, - LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_C = - { - nullptr, - LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_D = - { - nullptr, - LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.split_k_mode = configuration->split_k_mode; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - ConvArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.output_op = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.output_op = params; - } - else { - return Status::kErrorInvalidProblem; - } - - operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); - operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); - operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); - operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - Conv2dConfiguration const *configuration = - static_cast(configuration_ptr); - - ConvArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - //std::cout << "initialize library::Conv2dOperation" << std::endl; - //print_operator_args(args); - return op->initialize(args, device_workspace, stream); - - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - //std::cout << "run library::Conv2dOperation" << std::endl; - //print_operator_args(args); - return op->run(stream); - } - - /// Call print_operator_args from the Conv2dOperation::initialize() - // to dump arguments passed on to cutlass operator for debugging - void print_operator_args(OperatorArguments &operator_args) const { - std::cout << "Conv2dOperation::OperatorArguments" << std::endl - << " problem_size:" << std::endl - << operator_args.problem_size << std::endl - << " split_k_mode: " - << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl - << " epilogue (alpha, beta): " - << operator_args.output_op.alpha << ", " - << operator_args.output_op.beta << std::endl - << " ref_A (ptr, {stride}): " - << operator_args.ref_A.data() << ", {" - << operator_args.ref_A.stride(0) << ", " - << operator_args.ref_A.stride(1) << ", " - << operator_args.ref_A.stride(2) << "}" << std::endl - << " ref_B (ptr, {stride}): " - << operator_args.ref_B.data() << ", {" - << operator_args.ref_B.stride(0) << ", " - << operator_args.ref_B.stride(1) << ", " - << operator_args.ref_B.stride(2) << "}" << std::endl - << " ref_C (ptr, {stride}): " - << operator_args.ref_C.data() << ", {" - << operator_args.ref_C.stride(0) << ", " - << operator_args.ref_C.stride(1) << ", " - << operator_args.ref_C.stride(2) << "}" << std::endl - << " ref_D (ptr, {stride}): " - << operator_args.ref_D.data() << ", {" - << operator_args.ref_D.stride(0) << ", " - << operator_args.ref_D.stride(1) << ", " - << operator_args.ref_D.stride(2) << "}" << std::endl; - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// DirectConv2d library operation class for cutlass profiler -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class DirectConv2dOperation : public Conv2dOperation { -public: - - using Operator = Operator_; - using Base = Conv2dOperation; - - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; - - using OperatorArguments = typename Operator::Arguments; - -public: - /// Constructor - DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { - this->description_.conv_kind = ConvKindMap::kId; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - Conv2dConfiguration const *configuration) { - - - operator_args.problem_size = configuration->problem_size; - - operator_args.ref_A = - { - nullptr, - LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_B = - { - nullptr, - LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_reordered_B = - { - nullptr, - LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_C = - { - nullptr, - LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_D = - { - nullptr, - LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.split_k_mode = configuration->split_k_mode; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - ConvArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.output_op = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.output_op = params; - } - else { - return Status::kErrorInvalidProblem; - } - - operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); - operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); - operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); - operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); - operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - Conv2dConfiguration const *configuration = - static_cast(configuration_ptr); - - ConvArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - //std::cout << "initialize library::Conv2dOperation" << std::endl; - //print_operator_args(args); - return op->initialize(args, device_workspace, stream); - - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - //std::cout << "run library::Conv2dOperation" << std::endl; - //print_operator_args(args); - return op->run(stream); - } - - /// Call print_operator_args from the Conv2dOperation::initialize() - // to dump arguments passed on to cutlass operator for debugging - void print_operator_args(OperatorArguments &operator_args) const { - std::cout << "Conv2dOperation::OperatorArguments" << std::endl - << " problem_size:" << std::endl - << operator_args.problem_size << std::endl - << " split_k_mode: " - << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl - << " epilogue (alpha, beta): " - << operator_args.output_op.alpha << ", " - << operator_args.output_op.beta << std::endl - << " ref_A (ptr, {stride}): " - << operator_args.ref_A.data() << ", {" - << operator_args.ref_A.stride(0) << ", " - << operator_args.ref_A.stride(1) << ", " - << operator_args.ref_A.stride(2) << "}" << std::endl - << " ref_B (ptr, {stride}): " - << operator_args.ref_B.data() << ", {" - << operator_args.ref_B.stride(0) << ", " - << operator_args.ref_B.stride(1) << ", " - << operator_args.ref_B.stride(2) << "}" << std::endl - << " ref_C (ptr, {stride}): " - << operator_args.ref_C.data() << ", {" - << operator_args.ref_C.stride(0) << ", " - << operator_args.ref_C.stride(1) << ", " - << operator_args.ref_C.stride(2) << "}" << std::endl - << " ref_D (ptr, {stride}): " - << operator_args.ref_D.data() << ", {" - << operator_args.ref_D.stride(0) << ", " - << operator_args.ref_D.stride(1) << ", " - << operator_args.ref_D.stride(2) << "}" << std::endl; - } -}; - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h deleted file mode 100644 index fe402c4494c27a882bf42f867a708e954ee87dc0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h +++ /dev/null @@ -1,389 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all CONV operation kinds in CUTLASS Library. -*/ - -#pragma once -#include -#include "cutlass/cutlass.h" -#include "cutlass/conv/kernel/default_conv3d_fprop.h" -#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -#include "cutlass/conv/device/implicit_gemm_convolution.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/util/host_tensor.h" - -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/core_io.h" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class Conv3dOperationBase : public Operation { -public: - - using Operator = Operator_; - - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; - static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - ConvDescription description_; - -public: - - /// Constructor - Conv3dOperationBase(char const *name = "unknown_conv3d") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.kind = OperationKind::kConv3d; - description_.conv_dim = Operator::kConvDim; - - description_.iterator_algorithm = IteratorAlgorithmMap::kId; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::UnderlyingKernel::WarpCount::kM, - Operator::UnderlyingKernel::WarpCount::kN, - Operator::UnderlyingKernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(); - description_.B = make_TensorDescription(); - description_.C = make_TensorDescription(); - description_.element_epilogue = NumericTypeMap::kId; - - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Conv2d library operation class for cutlass profiler -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -template -class Conv3dOperation : public Conv3dOperationBase { -public: - - using Operator = Operator_; - - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; - - using OperatorArguments = typename Operator::Arguments; - -public: - /// Constructor - Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { - this->description_.conv_kind = ConvKindMap::kId; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - Conv3dConfiguration const *configuration) { - - - operator_args.problem_size = configuration->problem_size; - - operator_args.ref_A = - { - nullptr, - LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_B = - { - nullptr, - LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_C = - { - nullptr, - LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.ref_D = - { - nullptr, - LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) - }; - - operator_args.split_k_mode = configuration->split_k_mode; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - ConvArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.output_op = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.output_op = params; - } - else { - return Status::kErrorInvalidProblem; - } - - operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); - operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); - operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); - operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - Conv3dConfiguration const *configuration = - static_cast(configuration_ptr); - - ConvArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - //std::cout << "initialize library::Conv3dOperation" << std::endl; - //print_operator_args(args); - return op->initialize(args, device_workspace, stream); - - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - //std::cout << "run library::Conv3dOperation" << std::endl; - //print_operator_args(args); - return op->run(stream); - } - - /// Call print_operator_args from the Conv3dOperation::initialize() - // to dump arguments passed on to cutlass operator for debugging - void print_operator_args(OperatorArguments &operator_args) const { - std::cout << "Conv3dOperation::OperatorArguments" << std::endl - << " problem_size: " - << operator_args.problem_size << std::endl - << " split_k_mode: " - << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl - << " epilogue (alpha, beta): " - << operator_args.output_op.alpha << ", " - << operator_args.output_op.beta << std::endl - << " ref_A (ptr, {stride}): " - << operator_args.ref_A.data() << ", {" - << operator_args.ref_A.stride(0) << ", " - << operator_args.ref_A.stride(1) << ", " - << operator_args.ref_A.stride(2) << ", " - << operator_args.ref_A.stride(3) << "}" << std::endl - << " ref_B (ptr, {stride}): " - << operator_args.ref_B.data() << ", {" - << operator_args.ref_B.stride(0) << ", " - << operator_args.ref_B.stride(1) << ", " - << operator_args.ref_B.stride(2) << ", " - << operator_args.ref_B.stride(3) << "}" << std::endl - << " ref_C (ptr, {stride}): " - << operator_args.ref_C.data() << ", {" - << operator_args.ref_C.stride(0) << ", " - << operator_args.ref_C.stride(1) << ", " - << operator_args.ref_C.stride(2) << ", " - << operator_args.ref_C.stride(3) << "}" << std::endl - << " ref_D (ptr, {stride}): " - << operator_args.ref_D.data() << ", {" - << operator_args.ref_D.stride(0) << ", " - << operator_args.ref_D.stride(1) << ", " - << operator_args.ref_D.stride(2) << ", " - << operator_args.ref_D.stride(3) << "}" << std::endl; - } -}; - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp deleted file mode 100644 index 86c1513e9c934c22e281cf37e1c5e7783e23d305..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp +++ /dev/null @@ -1,980 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all CONV operation kinds in CUTLASS Library. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/conv/convnd_problem_shape.hpp" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/detail/dependent_false.hpp" -#include "cutlass/trace.h" -#include -#include -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) -#include -#endif - -namespace cutlass::library { - -namespace detail { - -template -constexpr cute::array -vector_to_array_strides_helper(const std::vector& v, - std::index_sequence) -{ - return {v[(sizeof...(Indices) - 1u) - Indices]..., ValueType(1)}; -} - -template -cute::array -vector_to_array_strides(const std::vector& v, std::integral_constant) -{ - static_assert(Size != 0); - CUTLASS_ASSERT(v.size() + 1u == Size); - return vector_to_array_strides_helper(v, std::make_index_sequence{}); -} - -template -constexpr cute::array -coord_to_array_strides_helper( - const ::cutlass::Coord coord, - std::index_sequence) -{ - return {int64_t(coord[(sizeof...(Indices) - 1u) - Indices])..., int64_t(1)}; -} - -template -cute::array -coord_to_array_strides(const ::cutlass::Coord& coord) -{ - static_assert(Rank >= 0); - return coord_to_array_strides_helper(coord, std::make_index_sequence{}); -} - -} // namespace detail - -// Tells the profiler about CUTLASS 3's 2-D and 3-D convolutions. -// For CUTLASS 2's 2-D convolutions, see Conv2dOperation. -// For CUTLASS 2's 3-D convolutions, see Conv3dOperation. -template -class ConvOperation3x : public Operation { -public: - using Operator = Operator_; - - static_assert(Operator::NumSpatialDimensions == 2 || - Operator::NumSpatialDimensions == 3, - "The profiler currently only supports convolutions with 2 or 3 spatial dimensions."); - using LayoutA = cute::conditional_t - >; - using LayoutB = LayoutA; - using LayoutC = LayoutA; - - using ElementA = typename Operator::ElementA; - using ElementB = typename Operator::ElementB; - using ElementC = typename Operator::ElementC; - using ElementD = typename Operator::ElementD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; - - ConvOperation3x(const char* name = "unknown_cutlass_3_conv") { - // Initialize OperationDescription (the base class) - description_.name = name; - description_.provider = Provider::kCUTLASS; - - if constexpr (Operator::NumSpatialDimensions == 2) { - description_.kind = OperationKind::kConv2d; - } - else if constexpr (Operator::NumSpatialDimensions == 3) { - description_.kind = OperationKind::kConv3d; - } - else { - static_assert(::cutlass::detail::dependent_false, - "This class currently only supports 2-D and 3-D convolutions."); - } - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::WarpCount::kM, - Operator::WarpCount::kN, - Operator::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationID::kMultiplyAdd; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - // Initialize ConvDescription (the subclass) - - // kConvDim does not exist in Operator for CUTLASS 3 convolutions. - // For CUTLASS 2 convolutions, it is the number of spatial dimensions. - description_.conv_dim = Operator::NumSpatialDimensions; - description_.conv_kind = ConvKindMap::kId; - - description_.iterator_algorithm = {}; - - description_.A = make_TensorDescription(); - description_.B = make_TensorDescription(); - description_.C = make_TensorDescription(); - description_.element_epilogue = NumericTypeMap::kId; - } - - ~ConvOperation3x() override = default; - - OperationDescription const& description() const override { - return static_cast(description_); - } - -private: - Status update_operator_arguments_from_configuration_2d_or_3d( - typename Operator::Arguments& out_args, - void const* configuration) const { - Status status = Status::kInvalid; - - CUTLASS_ASSERT(configuration != nullptr); - - if constexpr (Operator::NumSpatialDimensions == 2) { - CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); - // tools/library/include/cutlass/library/library.h - // defines Conv2dConfiguration. - // tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h - // uses Conv2dConfiguration. - auto* conf_ptr = reinterpret_cast(configuration); - status = update_operator_arguments_from_configuration(out_args, *conf_ptr); - } - else if constexpr (Operator::NumSpatialDimensions == 3) { - CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); - auto* conf_ptr = reinterpret_cast(configuration); - status = update_operator_arguments_from_configuration(out_args, *conf_ptr); - } - else { - static_assert(::cutlass::detail::dependent_false, - "This class currently only supports 2-D and 3-D convolutions."); - } - - return status; - } - -public: - Status can_implement( - void const* configuration, - void const* arguments) const override { - Status status = Status::kInvalid; - - // gemm_operation_3x.hpp accesses "configuration" as - // GemmUniversalConfiguration (which lives in - // tools/library/include/cutlass/library/library.h) and - // "arguments" as GemmUniversalArguments (which lives in - // tools/library/include/cutlass/library/library.h). - // Those things don't apply to convolutions. - // Despite the existence of ConvUniversal, there's no - // corresponding "ConvUniversalConfiguration" or - // "ConvUniversalArguments." - - CUTLASS_ASSERT(configuration != nullptr); - CUTLASS_ASSERT(arguments != nullptr); - - typename Operator::Arguments out_args{}; - status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); - if (status != Status::kSuccess) { - CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); - return status; - } - - auto* in_args_ptr = reinterpret_cast(arguments); - status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); - if (status != Status::kSuccess) { - CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); - return status; - } - - return Operator::can_implement(out_args); - } - - uint64_t get_host_workspace_size(void const* /* configuration */) const override { - return sizeof(Operator); - } - - uint64_t get_device_workspace_size( - void const* configuration, - void const* arguments = nullptr) const override - { - // This presumes that at least one of configuration or arguments is nonnull. - Status status = Status::kInvalid; - - // gemm_operation_3x.hpp has get_device_workspace_size return 0 on - // error. It's not clear that this is what we want -- perhaps we - // should return something like expected? -- but - // it's the only option that preserves the current interface. - constexpr uint64_t error_indication = 0; - - typename Operator::Arguments out_args{}; - if (configuration != nullptr) { - status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); - if (status != Status::kSuccess) { - return error_indication; - } - } - if (arguments != nullptr) { - auto* in_args_ptr = reinterpret_cast(arguments); - status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); - if (status != Status::kSuccess) { - return error_indication; - } - } - - if (status == Status::kSuccess) { - return static_cast(Operator::get_workspace_size(out_args)); - } - else { - return error_indication; - } - } - - Status initialize( - void const* configuration, - void* host_workspace, - void* /* device_workspace */ = nullptr, - cudaStream_t stream = nullptr) const override - { - Status status = Status::kInvalid; - - if (configuration == nullptr) { - CUTLASS_TRACE_HOST("Input configuration is null."); - return Status::kInvalid; - } - - typename Operator::Arguments out_args{}; - status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); - if (status != Status::kSuccess) { - // Any kind of failure invalidates the last successful configuration. - clear_last_successful_config(); - return status; - } - else { - set_last_successful_config(configuration); - } - - if (host_workspace == nullptr) { - CUTLASS_TRACE_HOST("host_workspace is null."); - return Status::kInvalid; - } - (void) new (host_workspace) Operator; - return status; - - // CUTLASS 2 convolutions call the Operator's initialize function - // here, like this. - // - //return op->initialize(args, device_workspace, stream); - // - // CUTLASS 3 convolutions (ConvUniversal), like CUTLASS 3 Gemms - // (GemmUniversal), lack an "initialize" member function. - } - - Status run( - void const* arguments, - void* host_workspace, - void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const override - { - auto status = Status::kInvalid; - - // The Operator doesn't appear to save the last configuration (it - // doesn't have a way to do that, since it lacks an initialize() - // member function), so we have to use the stored configuration - // from the last successful initialize() call (if any). - typename Operator::Arguments out_args{}; - status = update_operator_arguments_from_stored_configuration(out_args); - if (status != Status::kSuccess) { - CUTLASS_TRACE_HOST("Updating from previous successful configuration failed."); - return status; - } - - if (arguments == nullptr) { - CUTLASS_TRACE_HOST("Input argument 'arguments' is null."); - return Status::kInvalid; - } - auto* in_args_ptr = reinterpret_cast(arguments); - status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); - if (status != Status::kSuccess) { - return status; - } - - auto* op = reinterpret_cast(host_workspace); - return op->run(out_args, device_workspace, stream, nullptr, in_args_ptr->use_pdl); - } - -private: - ConvDescription description_; - // Result of initialize() calling - // update_operator_arguments_from_configuration() successfully. - // This is needed because run() doesn't take a configuration, just - // arguments, and the kernel doesn't appear to save the - // configuration from the last initialize() call. - // - // Unfortunately, this must be declared mutable, because it must be - // set in initialize(), and initialize() is inherited as const. - mutable std::variant< - std::monostate, - Conv2dConfiguration, - Conv3dConfiguration> last_successful_config_{std::monostate{}}; - - // Clear the last configuration resulting from a successful initialize() call. - // - // Unfortunately, this must be declared const, because initialize() is. - void clear_last_successful_config() const { - last_successful_config_ = std::monostate{}; - } - - // Set the last configuration resulting from a successful initialize() call. - // - // Unfortunately, this must be declared const, because initialize() is. - void set_last_successful_config(void const* configuration) const { - CUTLASS_ASSERT(configuration != nullptr); - - if constexpr (Operator::NumSpatialDimensions == 2) { - CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); - auto* conf_ptr = reinterpret_cast(configuration); - last_successful_config_ = *conf_ptr; - } else if constexpr (Operator::NumSpatialDimensions == 3) { - CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); - auto* conf_ptr = reinterpret_cast(configuration); - last_successful_config_ = *conf_ptr; - } - else { - static_assert(::cutlass::detail::dependent_false, - "This class currently only supports 2-D and 3-D convolutions."); - } - } - - // Whether a configuration from a successful initialize() call exists. - bool last_successful_config_exists() const { - return not std::holds_alternative(last_successful_config_); - } - - // Visitor for update_operator_arguments_from_stored_configuration. - struct ConfigurationVisitor { - typename Operator::Arguments& out_args; - - Status operator() (std::monostate const&) const { - CUTLASS_TRACE_HOST("No successful previous configuration exists. " - "One cause is calling run() before a successful initialize() call."); - return Status::kInvalid; - } - Status operator() (Conv2dConfiguration const& conf2d) const { - return update_operator_arguments_from_configuration(out_args, conf2d); - } - Status operator() (Conv3dConfiguration const& conf3d) const { - return update_operator_arguments_from_configuration(out_args, conf3d); - } - }; - - // Like update_operator_arguments_from_configuration, but on the - // stored configuration from the last successful initialize() call, - // if any. If there was no last successful initialize() call, - // then return Status::kInvalid. - // - // Unfortunately, this must be declared const, because run() is. - Status update_operator_arguments_from_stored_configuration( - typename Operator::Arguments& out_args) const - { - return std::visit(ConfigurationVisitor{out_args}, last_successful_config_); - } - - template - struct UpdateFusionArgs { - static Status update_( - FusionArgs const&, - ConvArguments const&) - { - // For custom EVT, it is the user's responsibility to ensure - // that alpha and beta are updated appropriately. - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_( - FusionArgs& fusion_args, - ConvArguments const& arguments) - { - if (arguments.pointer_mode == ScalarPointerMode::kHost) { - fusion_args.alpha = *static_cast(arguments.alpha); - fusion_args.beta = *static_cast(arguments.beta); - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - - return Status::kSuccess; - } - else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = static_cast(arguments.alpha); - fusion_args.beta_ptr = static_cast(arguments.beta); - - return Status::kSuccess; - } - else { - return Status::kErrorInvalidProblem; - } - } - }; - - static Status update_operator_arguments_from_configuration( - typename Operator::Arguments& out_args, - Conv2dConfiguration const& config) - { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("ConvOperator3x::" - "update_operator_arguments_from_configuration" - "(Conv2dConfiguration)\n"); -#endif - using detail::vector_to_array_strides; - - constexpr int num_spatial_dims = Operator::NumSpatialDimensions; - if constexpr (num_spatial_dims != 2) { - CUTLASS_TRACE_HOST("You can only use Conv2dConfiguration " - "with an Operator whose NumSpatialDimensions is exactly 2."); - return Status::kInvalid; - } - else { - // Convolutions split the metadata (in Conv2dConfiguration) from - // the data (ConvArguments, which only has pointers and a single - // enum value). Thus, this class will need both the - // configuration and the (user's input) arguments to set up the - // kernel's arguments. This function can fill in what the - // configuration has now, but the class will need the user's - // input arguments later. - if (config.split_k_mode != conv::SplitKMode::kSerial) { - CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); - return Status::kInvalid; - } - // config.problem_size.split_k_slices is only meaningful if - // split_k_mode != kSerial. If this code later supports other - // split_k_mode values, then it will also need to read - // split_k_slices. - - const int N = config.problem_size.N; - const int H = config.problem_size.H; - const int W = config.problem_size.W; - const int C = config.problem_size.C; - const int K = config.problem_size.K; - const int R = config.problem_size.R; - const int S = config.problem_size.S; - const int pad_h = config.problem_size.pad_h; - const int pad_w = config.problem_size.pad_w; - const int traversal_stride_h = config.problem_size.stride_h; - const int traversal_stride_w = config.problem_size.stride_w; - const int dilation_h = config.problem_size.dilation_h; - const int dilation_w = config.problem_size.dilation_w; - - // CUTLASS 3's implicit GEMM convolution kernels currently only - // support cross correlation (passing over the activation and - // filter tensors in the same order). The convolution mode is - // future work. - const auto mode = config.problem_size.mode; - if (mode != cutlass::conv::Mode::kCrossCorrelation) { - CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " - "are not currently supported."); - return Status::kInvalid; - } - - constexpr int num_spatial_dims = Operator::NumSpatialDimensions; - constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; - constexpr auto the_stride_size = std::integral_constant{}; - -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" - << " stride_size = " << stride_size << "\n"; - auto print_stride = [] (auto const& stride, char const variable_name[]) { - std::cerr << " " << variable_name << ": ["; - for (size_t k = 0; k < stride.size(); ++k) { - std::cerr << stride[k]; - if (k + 1u < stride.size()) { - std::cerr << ", "; - } - } - std::cerr << "]\n"; - }; - print_stride(config.stride_a, "config.stride_a"); - print_stride(config.stride_b, "config.stride_b"); - print_stride(config.stride_c, "config.stride_c"); -#endif - - // Conv2dConfiguration stores the strides as std::vector, - // so the code needs to check the run-time vector lengths. - if (config.stride_a.size() + 1u != stride_size) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) - std::ostringstream os; - os << "config.stride_a.size() + 1u = " - << (config.stride_a.size() + 1u) - << " != num_spatial_dims + 2u = " << stride_size; - CUTLASS_TRACE_HOST( os.str() ); -#endif - return Status::kInvalid; - } - if (config.stride_b.size() + 1u != stride_size) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) - std::ostringstream os; - os << "config.stride_b.size() + 1u = " - << (config.stride_b.size() + 1u) - << " != num_spatial_dims + 2u = " << stride_size; - CUTLASS_TRACE_HOST( os.str() ); -#endif - return Status::kInvalid; - } - if (config.stride_c.size() + 1u != stride_size) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) - std::ostringstream os; - os << "config.stride_c.size() + 1u = " - << (config.stride_c.size() + 1u) - << " != num_spatial_dims + 2u = " << stride_size; - CUTLASS_TRACE_HOST( os.str() ); -#endif - return Status::kInvalid; - } - - constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; - using problem_shape_type = - cutlass::conv::ConvProblemShape; - // cute::array; must convert to the kernel's native strides - using TensorStride = typename problem_shape_type::TensorStride; - - const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); - const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); - const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); - - // cutlass::library::Conv2dConfiguration has no member stride_d. - // The code below imitates the testbed, - // which just sets D's strides to C's strides. - - const int num_groups = config.problem_size.groups; - if (num_groups != 1) { - CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); - return Status::kInvalid; - } - // ConvProblemShape is how CUTLASS 3 kernels represent - // convolution problems. ConvProblemShape's constructors take - // shape_act, stride_act, shape_flt, and stride_flt, and set - // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C - // according to Fprop / Dgrad / Wgrad. - // - // This means that stride_act isn't always config.stride_A, - // depending on Fprop / Dgrad / Wgrad. The code here "undoes" - // the logic in Conv2dWorkspace::set_stride_vector so that we - // can recover the strides of the activation and filter tensors. - // It doesn't need to worry about the so-called "output" tensor - // (which might not be C), as ConvProblemShape's constructor - // figures out its shapes and strides. - using TensorExtent = typename problem_shape_type::TensorExtent; - TensorExtent shape_act{N, H, W, C}; - auto stride_act = [&] () { - // Some compilers consider conv_op (defined above), as - // captured by this lambda, as "not a constant expression." - constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; - if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { - return stride_A; - } - else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { - return stride_C; - } - else { // conv_kind == cutlass::conv::Operator::kWgrad - return stride_B; - } - } (); - TensorExtent shape_flt{K, R, S, C}; - auto stride_flt = [&] () { - // Some compilers consider conv_op (defined above), as - // captured by this lambda, as "not a constant expression." - constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; - if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { - return stride_B; - } - else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { - return stride_B; - } - else { // conv_kind == cutlass::conv::Operator::kWgrad - return stride_C; - } - } (); - - problem_shape_type problem_shape( - /* mode = */ mode, - /* shape_act = */ shape_act, - /* stride_act = */ stride_act, - /* shape_flt = */ shape_flt, - /* stride_flt = */ stride_flt, - /* lower_padding = */ {pad_h, pad_w}, - /* upper_padding = */ {pad_h, pad_w}, - /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, - /* dilation = */ {dilation_h, dilation_w}, - num_groups); - out_args.problem_shape = problem_shape; - - // ConvProblemShape's constructor sets its shape_C member. -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - printf("\n problem_shape.shape_C: "); - print(problem_shape.shape_C); - printf("\n problem_shape.stride_C: "); - print(problem_shape.stride_C); - printf("\n"); -#endif - // Initialization of C's and D's strides follows the CUTLASS 3 - // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). - { - using StrideC = typename Operator::ConvKernel::StrideC; - using StrideD = typename Operator::ConvKernel::StrideD; - auto stride_C = StrideC{}; - auto stride_D = StrideD{}; - - if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { - stride_C = cutlass::make_cute_packed_stride( - StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); - stride_D = cutlass::make_cute_packed_stride( - StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; -#endif - } - else { - cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " - << stride_C_i << "\n"; -#endif - cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - }); - cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " - << stride_D_i << "\n"; -#endif - cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - }); - } - out_args.epilogue.dC = stride_C; - out_args.epilogue.dD = stride_D; - } - return Status::kSuccess; - } - } - - static Status update_operator_arguments_from_configuration( - typename Operator::Arguments& out_args, - Conv3dConfiguration const& config) - { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("ConvOperator3x::" - "update_operator_arguments_from_configuration" - "(Conv3dConfiguration)\n"); -#endif - using detail::coord_to_array_strides; - - constexpr int num_spatial_dims = Operator::NumSpatialDimensions; - if constexpr (num_spatial_dims != 3) { - CUTLASS_TRACE_HOST("You can only use Conv3dConfiguration " - "with an Operator whose NumSpatialDimensions is exactly 3."); - return Status::kInvalid; - } - else { - // Convolutions split the metadata (in Conv3dConfiguration) from - // the data (ConvArguments, which only has pointers and a single - // enum value). Thus, this class will need both the - // configuration and the (user's input) arguments to set up the - // kernel's arguments. This function can fill in what the - // configuration has now, but the class will need the user's - // input arguments later. - if (config.split_k_mode != conv::SplitKMode::kSerial) { - CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); - return Status::kInvalid; - } - // config.problem_size.split_k_slices is only meaningful if - // split_k_mode != kSerial. If this code later supports other - // split_k_mode values, then it will also need to read - // split_k_slices. - - const int N = config.problem_size.N; - const int D = config.problem_size.D; - const int H = config.problem_size.H; - const int W = config.problem_size.W; - const int C = config.problem_size.C; - const int K = config.problem_size.K; - const int T = config.problem_size.T; - const int R = config.problem_size.R; - const int S = config.problem_size.S; - const int pad_d = config.problem_size.pad_d; - const int pad_h = config.problem_size.pad_h; - const int pad_w = config.problem_size.pad_w; - const int traversal_stride_d = config.problem_size.stride_d; - const int traversal_stride_h = config.problem_size.stride_h; - const int traversal_stride_w = config.problem_size.stride_w; - const int dilation_d = config.problem_size.dilation_d; - const int dilation_h = config.problem_size.dilation_h; - const int dilation_w = config.problem_size.dilation_w; - - // CUTLASS 3's implicit GEMM convolution kernels currently only - // support cross correlation (passing over the activation and - // filter tensors in the same order). The convolution mode is - // future work. - const auto mode = config.problem_size.mode; - if (mode != cutlass::conv::Mode::kCrossCorrelation) { - CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " - "are not currently supported."); - return Status::kInvalid; - } - - using Stride = cutlass::layout::TensorNDHWC::Stride; - static_assert(std::is_same_v>); - - const cutlass::library::ConvKind conv_kind = [] () { - constexpr cutlass::conv::Operator op = Operator::DispatchPolicy::ConvOp; - if constexpr (op == cutlass::conv::Operator::kFprop) { - return library::ConvKind::kFprop; - } - else if constexpr (op == cutlass::conv::Operator::kDgrad) { - return library::ConvKind::kDgrad; - } - else /* if constexpr (op == cutlass::conv::Operator::kWgrad) */ { - return library::ConvKind::kWgrad; - } - } (); - const Stride input_stride_a = config.layout_a(conv_kind).stride(); - const Stride input_stride_b = config.layout_b(conv_kind).stride(); - const Stride input_stride_c = config.layout_c(conv_kind).stride(); - -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; - std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" - << " stride_size = " << stride_size << "\n"; - auto print_stride = [] (Stride const& stride, char const variable_name[]) { - std::cerr << " " << variable_name << ": ["; - for (size_t k = 0; k < Stride::kRank; ++k) { - std::cerr << stride[static_cast(k)]; - if (k + 1u < Stride::kRank) { - std::cerr << ", "; - } - } - std::cerr << "]\n"; - }; - print_stride(input_stride_a, "input_stride_a"); - print_stride(input_stride_b, "input_stride_b"); - print_stride(input_stride_c, "input_stride_c"); -#endif - // Conv3dConfiguration stores the strides as Coord (with - // compile-time size), so there's no need to check sizes here - // (unlike Conv2dConfiguration, which stores strides as - // std::vector). - - constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; - using problem_shape_type = - cutlass::conv::ConvProblemShape; - // cute::array; must convert to the kernel's native strides - using TensorStride = typename problem_shape_type::TensorStride; - - const TensorStride stride_A = coord_to_array_strides(input_stride_a); - const TensorStride stride_B = coord_to_array_strides(input_stride_b); - const TensorStride stride_C = coord_to_array_strides(input_stride_c); - - const int num_groups = config.problem_size.groups; - if (num_groups != 1) { - CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); - return Status::kInvalid; - } - // ConvProblemShape is how CUTLASS 3 kernels represent - // convolution problems. ConvProblemShape's constructors take - // shape_act, stride_act, shape_flt, and stride_flt, and set - // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C - // according to Fprop / Dgrad / Wgrad. - // - // Conv3dConfiguration differs a bit from Conv2dConfiguration, - // but the idea is the same: the "input_stride_a" from config - // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act - // isn't always input_stride_a. Analogously, stride_flt isn't - // always input_stride_b. The code here "undoes" the logic in - // config.layout_a(conv_kind) and config.layout_b(conv_kind) - // (analogous to Conv2dWorkspace::set_stride_vector) so that we - // can recover the strides of the activation and filter tensors. - // It doesn't need to worry about the so-called "output" tensor - // (which might not be C), as ConvProblemShape's constructor - // figures out its shapes and strides. - using TensorExtent = typename problem_shape_type::TensorExtent; - TensorExtent shape_act{N, D, H, W, C}; - auto stride_act = [&] () { - // Some compilers consider conv_op (defined above), as - // captured by this lambda, as "not a constant expression." - constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; - if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { - return stride_A; - } - else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { - return stride_C; - } - else { // conv_kind == cutlass::conv::Operator::kWgrad - return stride_B; - } - } (); - TensorExtent shape_flt{K, T, R, S, C}; - auto stride_flt = [&] () { - // Some compilers consider conv_op (defined above), as - // captured by this lambda, as "not a constant expression." - constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; - if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { - return stride_B; - } - else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { - return stride_B; - } - else { // conv_kind == cutlass::conv::Operator::kWgrad - return stride_C; - } - } (); - - problem_shape_type problem_shape( - /* mode = */ mode, - /* shape_act = */ shape_act, - /* stride_act = */ stride_act, - /* shape_flt = */ shape_flt, - /* stride_flt = */ stride_flt, - /* lower_padding = */ {pad_d, pad_h, pad_w}, - /* upper_padding = */ {pad_d, pad_h, pad_w}, - /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, - /* dilation = */ {dilation_d, dilation_h, dilation_w}, - num_groups); - out_args.problem_shape = problem_shape; - - // ConvProblemShape's constructor sets its shape_C member. -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - printf("\n problem_shape.shape_C: "); - print(problem_shape.shape_C); - printf("\n problem_shape.stride_C: "); - print(problem_shape.stride_C); - printf("\n"); -#endif - // Initialization of C's and D's strides follows the CUTLASS 3 - // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). - { - using StrideC = typename Operator::ConvKernel::StrideC; - using StrideD = typename Operator::ConvKernel::StrideD; - auto stride_C = StrideC{}; - auto stride_D = StrideD{}; - - if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { - stride_C = cutlass::make_cute_packed_stride( - StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); - stride_D = cutlass::make_cute_packed_stride( - StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; -#endif - } - else { - cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " - << stride_C_i << "\n"; -#endif - cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - }); - cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " - << stride_D_i << "\n"; -#endif - cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; - }); - } - out_args.epilogue.dC = stride_C; - out_args.epilogue.dD = stride_D; - } - return Status::kSuccess; - } - } - - Status update_operator_arguments_from_arguments( - typename Operator::Arguments& out_args, - ConvArguments const& in_args) const - { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); -#endif - auto status = UpdateFusionArgs::update_( - out_args.epilogue.thread, in_args); - if (status != Status::kSuccess) { - return status; - } - - out_args.mainloop.ptr_A = reinterpret_cast(in_args.A); - out_args.mainloop.ptr_B = reinterpret_cast(in_args.B); - - out_args.epilogue.ptr_C = reinterpret_cast(in_args.C); - out_args.epilogue.ptr_D = reinterpret_cast(in_args.D); - - return Status::kSuccess; - } -}; - -} // namespace cutlass::library diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h deleted file mode 100644 index 880cb4bf34b1f3d946e1dc86b80806309bb2b3c1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h +++ /dev/null @@ -1,1408 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -*/ - -#pragma once -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/device/gemm.h" -#include "cutlass/gemm/device/gemm_sparse.h" -#include "cutlass/gemm/device/gemm_complex.h" -#include "cutlass/gemm/device/gemm_batched.h" -#include "cutlass/gemm/device/gemm_array.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmOperationBase : public Operation { -public: - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - // assuming all tensors use same type for StrideIndex - using StrideIndex = typename Operator::LayoutA::Index; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - GemmDescription description_; - -public: - - /// Constructor - GemmOperationBase(char const *name = "unknown_gemm") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.kind = OperationKind::kGemm; - description_.gemm_kind = GemmKind::kGemm; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::GemmKernel::WarpCount::kM, - Operator::GemmKernel::WarpCount::kN, - Operator::GemmKernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.D = make_TensorDescription(Operator::kAlignmentC); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - description_.transform_A = ComplexTransformMap::kId; - description_.transform_B = ComplexTransformMap::kId; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { - - this->description_.gemm_kind = GemmKind::kGemm; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - GemmConfiguration const *configuration) { - - operator_args.problem_size = configuration->problem_size; - - operator_args.ref_A = {nullptr, configuration->lda}; - operator_args.ref_B = {nullptr, configuration->ldb}; - operator_args.ref_C = {nullptr, configuration->ldc}; - operator_args.ref_D = {nullptr, configuration->ldd}; - - operator_args.split_k_slices = configuration->split_k_slices; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - GemmArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - operator_args.ref_A.reset(static_cast(arguments->A)); - operator_args.ref_B.reset(static_cast(arguments->B)); - operator_args.ref_C.reset(static_cast(arguments->C)); - operator_args.ref_D.reset(static_cast(arguments->D)); - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - GemmConfiguration const *configuration = - static_cast(configuration_ptr); - - GemmArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - return op->initialize(args, device_workspace, stream); - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args); - - if (status != Status::kSuccess) { - return status; - } - - return op->run(stream); - } - - void print_operator_args(OperatorArguments &operator_args) const { -#if 0 - std::cout << "GemmOperation::OperatorArguments" << std::endl; - std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; - std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; - std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; - std::cout << " beta: " << operator_args.epilogue.beta << std::endl; - std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; - std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; - std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; - std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; - std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; - std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; - std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmSparseOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - using ElementE = typename Operator::ElementE; - using LayoutE = typename Operator::LayoutE; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { - - this->description_.kind = OperationKind::kSparseGemm; - this->description_.gemm_kind = GemmKind::kSparse; - this->description_.E = make_TensorDescription(Operator::kAlignmentE); - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - SparseGemmConfiguration const *configuration) { - - operator_args.problem_size = configuration->problem_size; - operator_args.ref_A = {nullptr, configuration->lda}; - operator_args.ref_B = {nullptr, configuration->ldb}; - operator_args.ref_C = {nullptr, configuration->ldc}; - operator_args.ref_D = {nullptr, configuration->ldd}; - operator_args.ref_E = {nullptr, configuration->lde}; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - SparseGemmArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - operator_args.ref_A.reset(static_cast(arguments->A)); - operator_args.ref_B.reset(static_cast(arguments->B)); - operator_args.ref_C.reset(static_cast(arguments->C)); - operator_args.ref_D.reset(static_cast(arguments->D)); - operator_args.ref_E.reset(static_cast(arguments->E)); - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - SparseGemmConfiguration const *configuration = - static_cast(configuration_ptr); - - SparseGemmArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - return op->initialize(args, device_workspace, stream); - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args); - - if (status != Status::kSuccess) { - return status; - } - - return op->run(stream); - } - - void print_operator_args(OperatorArguments &operator_args) const { -#if 0 - std::cout << "GemmOperation::OperatorArguments" << std::endl; - std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; - std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; - std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; - std::cout << " beta: " << operator_args.epilogue.beta << std::endl; - std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; - std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; - std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; - std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; - std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; - std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; - std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - GemmUniversalOperation(char const *name = "unknown_gemm"): - GemmOperationBase(name) { - - this->description_.gemm_kind = GemmKind::kUniversal; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - GemmUniversalConfiguration const *configuration) { - - operator_args.mode = configuration->mode; - - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - operator_args.lda = (configuration->lda); - operator_args.ldb = (configuration->ldb); - operator_args.ldc = (configuration->ldc); - operator_args.ldd = (configuration->ldd); - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - GemmUniversalArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A = arguments->A; - operator_args.ptr_B = arguments->B; - operator_args.ptr_C = arguments->C; - operator_args.ptr_D = arguments->D; - - operator_args.batch_stride_A = arguments->batch_stride_A; - operator_args.batch_stride_B = arguments->batch_stride_B; - operator_args.batch_stride_C = arguments->batch_stride_C; - operator_args.batch_stride_D = arguments->batch_stride_D; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - GemmUniversalConfiguration const *configuration = - static_cast(configuration_ptr); - - GemmUniversalArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args); - - if (status != Status::kSuccess) { - return status; - } - - status = op->run(stream); - - return status; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmPlanarComplexOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { - - this->description_.gemm_kind = GemmKind::kPlanarComplex; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - GemmPlanarComplexConfiguration const *configuration) { - - operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - - operator_args.lda_real = configuration->lda_real; - operator_args.lda_imag = configuration->lda_imag; - operator_args.ldb_real = configuration->ldb_real; - operator_args.ldb_imag = configuration->ldb_imag; - operator_args.ldc_real = configuration->ldc_real; - operator_args.ldc_imag = configuration->ldc_imag; - operator_args.ldd_real = configuration->ldd_real; - operator_args.ldd_imag = configuration->ldd_imag; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - GemmPlanarComplexArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast const *>(arguments->alpha), - *static_cast const *>(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast const *>(arguments->alpha), - static_cast const *>(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A_real = arguments->A_real; - operator_args.ptr_A_imag = arguments->A_imag; - operator_args.ptr_B_real = arguments->B_real; - operator_args.ptr_B_imag = arguments->B_imag; - operator_args.ptr_C_real = arguments->C_real; - operator_args.ptr_C_imag = arguments->C_imag; - operator_args.ptr_D_real = arguments->D_real; - operator_args.ptr_D_imag = arguments->D_imag; - - operator_args.batch_stride_A = arguments->batch_stride_A_real; - operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; - operator_args.batch_stride_B = arguments->batch_stride_B_real; - operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; - operator_args.batch_stride_C = arguments->batch_stride_C_real; - operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; - operator_args.batch_stride_D = arguments->batch_stride_D_real; - operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - GemmPlanarComplexConfiguration const *configuration = - static_cast(configuration_ptr); - - GemmPlanarComplexArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args); - - if (status != Status::kSuccess) { - return status; - } - - status = op->run(stream); - - return status; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmPlanarComplexArrayOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { - - this->description_.gemm_kind = GemmKind::kPlanarComplexArray; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - GemmPlanarComplexArrayConfiguration const *configuration) { - - operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - operator_args.lda_real = configuration->lda_real; - operator_args.lda_imag = configuration->lda_imag; - operator_args.ldb_real = configuration->ldb_real; - operator_args.ldb_imag = configuration->ldb_imag; - operator_args.ldc_real = configuration->ldc_real; - operator_args.ldc_imag = configuration->ldc_imag; - operator_args.ldd_real = configuration->ldd_real; - operator_args.ldd_imag = configuration->ldd_imag; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - GemmPlanarComplexArrayArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast const *>(arguments->alpha), - *static_cast const *>(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast const *>(arguments->alpha), - static_cast const *>(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A_real = arguments->A_real; - operator_args.ptr_A_imag = arguments->A_imag; - operator_args.ptr_B_real = arguments->B_real; - operator_args.ptr_B_imag = arguments->B_imag; - operator_args.ptr_C_real = arguments->C_real; - operator_args.ptr_C_imag = arguments->C_imag; - operator_args.ptr_D_real = arguments->D_real; - operator_args.ptr_D_imag = arguments->D_imag; - - operator_args.ptr_M = arguments->M; - operator_args.ptr_N = arguments->N; - operator_args.ptr_K = arguments->K; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - GemmPlanarComplexArrayConfiguration const *configuration = - static_cast(configuration_ptr); - - GemmPlanarComplexArrayArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args); - - if (status != Status::kSuccess) { - return status; - } - - status = op->run(stream); - - return status; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmGroupedOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = ElementC; - using LayoutD = LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - GemmGroupedOperation(char const *name = "unknown_gemm"): - GemmOperationBase(name) { - - this->description_.kind = OperationKind::kGroupedGemm; - this->description_.provider = Provider::kCUTLASS; - this->threadblock_count = Operator::sufficient(); - - this->description_.gemm = GemmOperationBase::description_; - this->description_.gemm.gemm_kind = GemmKind::kGrouped; - this->description_.tile_description = this->description_.gemm.tile_description; - } - - /// Returns the description of the GroupedGEMM operation - virtual OperationDescription const & description() const override final { - return description_; - } - - -private: - int threadblock_count; - GroupedGemmDescription description_; - -protected: - - /// Constructs the arguments structure given the configuration and arguments - Status construct_arguments_( - OperatorArguments &op_args, - GemmGroupedConfiguration const *config) const { - - op_args.problem_count = config->problem_count; - op_args.threadblock_count = threadblock_count; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - Status update_arguments_( - OperatorArguments &op_args, - GemmGroupedArguments const *arguments) const { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - - op_args.output_op = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { - - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - - op_args.output_op = params; - } - else { - return Status::kErrorInvalidProblem; - } - - op_args.threadblock_count = threadblock_count; - op_args.problem_count = arguments->problem_count; - op_args.problem_sizes = arguments->problem_sizes; - - op_args.ptr_A = static_cast(arguments->ptr_A); - op_args.ptr_B = static_cast(arguments->ptr_B); - op_args.ptr_C = static_cast(arguments->ptr_C); - op_args.ptr_D = static_cast(arguments->ptr_D); - - op_args.lda = arguments->lda; - op_args.ldb = arguments->ldb; - op_args.ldc = arguments->ldc; - op_args.ldd = arguments->ldd; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - GemmGroupedConfiguration const *configuration = - static_cast(configuration_ptr); - - GemmGroupedArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args); - - if (status != Status::kSuccess) { - return status; - } - - status = op->run(stream); - - return status; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp deleted file mode 100644 index 2c1d17943f11fe8126b3070c3fcead5598e2d207..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp +++ /dev/null @@ -1,714 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/collective.hpp" -#include "cutlass/array.h" -#include "cutlass/array_subbyte.h" -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/mixed_dtype_utils.hpp" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/device/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cute/tensor.hpp" -#include - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmOperation3xBase : public Operation { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - // assuming all tensors use same type for StrideIndex - using StrideIndex = typename Operator::LayoutA::Index; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - -protected: - GemmDescription description_; - -public: - - /// Constructor - GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.kind = OperationKind::kGemm; - description_.gemm_kind = gemm_kind_; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { - description_.tile_description.cluster_shape = make_Coord( - Operator::ClusterShape::kM, - Operator::ClusterShape::kN, - Operator::ClusterShape::kK); - } - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::WarpCount::kM, - Operator::WarpCount::kN, - Operator::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.D = make_TensorDescription(Operator::kAlignmentD); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - description_.transform_A = ComplexTransformMap::kId; - description_.transform_B = ComplexTransformMap::kId; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - /// Returns the description of the GEMM operation - GemmDescription const& get_gemm_description() const { - return description_; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversal3xOperation : public GemmOperation3xBase { -public: - - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), - "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); - - static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - - -public: - - /// Constructor - GemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) { - if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { - dim3 cluster_dims( - cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); - uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; - void const* kernel_ptr = (void*)(device_kernel); - max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( - cluster_dims, - threads_per_block, - kernel_ptr); - } - } - -private: - int max_active_clusters{}; - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { - // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides - // Do nothing here and construct kernel arguments in update_arguments_ instead - // We also cannot construct TMA descriptors without all the arguments available - - operator_args.mode = configuration->mode; - return Status::kSuccess; - } - - template - struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { - if (arguments.pointer_mode == ScalarPointerMode::kHost) { - fusion_args.alpha = *static_cast(arguments.alpha); - fusion_args.beta = *static_cast(arguments.beta); - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - - return Status::kSuccess; - } - else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = static_cast(arguments.alpha); - fusion_args.beta_ptr = static_cast(arguments.beta); - - return Status::kSuccess; - } - else { - return Status::kErrorInvalidProblem; - } - } - }; - - template class Policy, int Stages, class ClusterShape, class KernelSchedule> - static constexpr bool is_sm90_mixed_dtype_mainloop_(Policy policy) { - return (cute::is_same_v, - cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>); - } - - template - static constexpr bool is_sm90_mixed_dtype_mainloop_(DispatchPolicy) { - return false; - } - - template < - typename ElementWide, - typename ElementNarrow, - typename ElementScaleMainloop, - class ActualStrideAB, - Sm90MixedInputWiderOperand wider_operand, - bool is_n4w8, - typename ElementScale, - typename ElementZero, - class Layout_SZ> - static void dequantize_encode_( - OperatorArguments &operator_args, - GemmUniversalArguments const *arguments, - cudaStream_t stream, - const int &problem_mn, - const int &problem_k, - const int &options_l, - const int &options_g, - ElementScale *ptr_S, - ElementZero *ptr_Z, - const size_t &SZ_size, - Layout_SZ layout_SZ - ) { - - auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); - auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); - auto layout_AB = cute::make_layout(shape_AB, stride_AB); - auto *ptr_dequantized_AB = static_cast(arguments->dequantized_AB); - const ElementNarrow *ptr_AB = nullptr; - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { - ptr_AB = static_cast(arguments->B); - } - else { - ptr_AB = static_cast(arguments->A); - } - dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream); - if constexpr(is_n4w8) { - size_t AB_size = cute::size(layout_AB); - cutlass::int4b_t *encoded_AB = static_cast(arguments->encoded_AB); - unified_encode_int4b(ptr_AB, encoded_AB, AB_size); - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { - operator_args.mainloop.ptr_B = static_cast(encoded_AB); - } - else { - operator_args.mainloop.ptr_A = static_cast(encoded_AB); - } - ElementScaleMainloop *ptr_packed_Scale = static_cast(arguments->packed_Scale); - pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size); - } - } - - template < - typename ElementAB, - class ActualStrideAB, - class LayoutAB_Reordered, - class LayoutAtomQuant, - Sm90MixedInputWiderOperand wider_operand> - static void handle_shuffle_tensor_( - OperatorArguments &operator_args, - GemmUniversalArguments const *arguments, - const int &problem_mn, - const int &problem_k, - const int &options_l) { - - auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); - auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); - auto layout_AB = cute::make_layout(shape_AB, stride_AB); - LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB); - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { - operator_args.mainloop.dB = layout_AB_reordered; - } - else { - operator_args.mainloop.dA = layout_AB_reordered; - } - if (arguments->generate_dequantized_AB) { - size_t AB_size = cute::size(layout_AB); - ElementAB *AB_reordered = cutlass::device_memory::allocate(AB_size); - const ElementAB *AB_src = nullptr; - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { - AB_src = static_cast(operator_args.mainloop.ptr_B); - } - else { - AB_src = static_cast(operator_args.mainloop.ptr_A); - } - reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered); - ElementAB *AB_dst = static_cast(arguments->encoded_AB); - cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size); - cutlass::device_memory::free(AB_reordered); - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { - operator_args.mainloop.ptr_B = AB_dst; - } - else { - operator_args.mainloop.ptr_A = AB_dst; - } - } - } - - /// Constructs the arguments structure given the configuration and arguments - Status update_arguments_( - OperatorArguments& operator_args, - GemmUniversalArguments const* arguments, - cudaStream_t stream = nullptr) const { - Status status = Status::kSuccess; - - status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, *arguments); - if (status != Status::kSuccess) { - return status; - } - - // TODO: type erase Arguments structure in 3.0 GEMM - operator_args.problem_shape = cute::make_shape( - arguments->problem_size.m(), - arguments->problem_size.n(), - arguments->problem_size.k(), - arguments->batch_count); - - // update arguments - - if constexpr (IsRuntimeDataType) { - using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; - operator_args.mainloop.ptr_A = static_cast(arguments->A); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - - std::unordered_map mapping = { - {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, - {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, - {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, - {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} - }; - - auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); - auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); - - if (iter_runtime_a != mapping.end()) { - operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; - } else { - assert("invalid runtime argument for datatype A!"); - } - - if (iter_runtime_b != mapping.end()) { - operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; - } else { - assert("invalid runtime argument for datatype B!"); - } - - } - else { - operator_args.mainloop.ptr_A = static_cast(arguments->A); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - } - operator_args.epilogue.ptr_C = static_cast(arguments->C); - operator_args.epilogue.ptr_D = static_cast(arguments->D); - - // Stride{A,B} is a Layout if and only if: - // (1) This is a mixed dtype kernel, and - // (2) This mixed dtype kernel is using shuffling, and - // (3) sizeof(narrow_type) == 4 or 8 bits, and - // (4) sizeof(wide_type) == 16 bits. - // If A/B has the narrow data type, Stride{A/B} will be a Layout - constexpr bool is_StrideA_Layout = cute::is_layout::value; - constexpr bool is_StrideB_Layout = cute::is_layout::value; - static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout"); - if constexpr(!is_StrideA_Layout) { - operator_args.mainloop.dA = cute::make_int_tuple_from( - arguments->lda, arguments->batch_stride_A); - } - if constexpr(!is_StrideB_Layout) { - operator_args.mainloop.dB = cute::make_int_tuple_from( - arguments->ldb, arguments->batch_stride_B); - } - operator_args.epilogue.dC = cute::make_int_tuple_from( - arguments->ldc, arguments->batch_stride_C); - operator_args.epilogue.dD = operator_args.epilogue.dC; - - using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy; - if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{})) { - const int problem_m = arguments->problem_size.m(); - const int problem_n = arguments->problem_size.n(); - const int problem_k = arguments->problem_size.k(); - const int options_l = arguments->batch_count; - - constexpr Sm90MixedInputWiderOperand wider_operand = - (cutlass::sizeof_bits::value > cutlass::sizeof_bits::value) ? - Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B; - using ElementWide = std::conditional_t; - using ElementNarrow = std::conditional_t; - - constexpr bool has_scale = !std::is_same_v; - constexpr bool has_zero = !std::is_same_v; - - const int options_g = problem_k; - const int scale_k = (problem_k + options_g - 1) / options_g; - - constexpr bool is_A4B8 = ( - cutlass::is_same_v && - (cutlass::is_same_v || - cutlass::is_same_v)); - constexpr bool is_A8B4 = ( - cutlass::is_same_v && - (cutlass::is_same_v || - cutlass::is_same_v)); - constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4; - - // If this is a convert-only kernel, we still need to generate dequantized A or B for verification, - // and in this case ElementScale is the same as ElementWide - // In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element - using DummyElementScaleMainloop = std::conditional_t< - is_int4_x_fp8, - typename cutlass::Array, - ElementWide - >; - using ElementScaleMainloop = std::conditional_t< - has_scale, - typename CollectiveMainloop::ElementScale, - DummyElementScaleMainloop - >; - using ElementScale = std::conditional_t< - has_scale, - typename UnderlyingElement::type, - ElementWide - >; - using StrideScale = typename CollectiveMainloop::StrideScale; - // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S - using ElementZero = std::conditional_t< - has_zero, - typename CollectiveMainloop::ElementZero, - ElementScale - >; - const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m; - const size_t SZ_size = static_cast(SZ_1st_dim * scale_k * options_l); - auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l); - ElementScale *ptr_S = static_cast(arguments->Scale); - ElementZero *ptr_Z = static_cast(arguments->Zero); - - // 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled - if (arguments->generate_scale_and_zero) { - float scale_min = 1.0f, scale_max = 1.0f; - if constexpr(has_scale) { - const float elt_max_f = float(cutlass::platform::numeric_limits::max()); - // Need to fix max_dequant_val and min_dequant_val? - const float max_dequant_val = elt_max_f * 0.25f; - const float min_dequant_val = 0.5f; - scale_max = max_dequant_val / elt_max_f; - scale_min = min_dequant_val / elt_max_f; - } - uint64_t seed = 2023; - cutlass::reference::device::BlockFillRandomUniform( - ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min)); - - // In ScaleOnly mode, set Z as zero for generating dequantized A or B - const float zero_max = has_zero ? 2.0f : 0.0f; - const float zero_min = has_zero ? -2.0f : 0.0f; - cutlass::reference::device::BlockFillRandomUniform( - ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min)); - } // End of "if (arguments->generate_scale_and_zero)" - - // 2. Generate the dequantized A or B for verification - if (arguments->generate_dequantized_AB) { - StrideScale stride_SZ = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); - auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ); - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { - if constexpr(is_StrideB_Layout) { - // The generator only generates row-major A and col-major B at the moment - // Need a way to read out the actual layout of B later - using ActualLayoutB = cutlass::layout::ColumnMajor; - using ActualStrideB = cutlass::detail::TagToStrideB_t; - dequantize_encode_( - operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); - } - else { - using ActualStrideB = typename CollectiveMainloop::StrideB; - dequantize_encode_( - operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); - } - } - else { - if constexpr(is_StrideA_Layout) { - // The generator only generates row-major A and col-major B at the moment - // Need a way to read out the actual layout of A later - using ActualLayoutA = cutlass::layout::RowMajor; - using ActualStrideA = cutlass::detail::TagToStrideA_t; - dequantize_encode_( - operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); - } - else { - using ActualStrideA = typename CollectiveMainloop::StrideA; - dequantize_encode_( - operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); - } - } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)" - } // End of "if (arguments->generate_dequantized_AB)" - - // 3. Put Scale and Zero in mainloop - if constexpr(has_scale) { - if constexpr(is_int4_x_fp8) { - operator_args.mainloop.ptr_S = static_cast(arguments->packed_Scale); - } - else { - operator_args.mainloop.ptr_S = static_cast(arguments->Scale); - } - operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); - operator_args.mainloop.group_size = options_g; - if constexpr(has_zero) { - operator_args.mainloop.ptr_Z = static_cast(arguments->Zero); - } - } // End of "if constexpr(has_scale)" - - // Handle the shuffling - using ValueShuffle = std::conditional_t< - cutlass::sizeof_bits::value == 4, - cute::Layout, cute::Stride>, - cute::Layout, cute::Stride> - >; - constexpr int NumShuffleAtoms = 1; - using MmaAtomShape = cute::Layout>>; - using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); - // The generator only generates row-major A and col-major B at the moment - // Need a way to read out the actual layout and stride of A/B later - if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) { - using ActualLayoutB = cutlass::layout::ColumnMajor; - using ActualStrideB = cutlass::detail::TagToStrideB_t; - using LayoutB_Reordered = typename CollectiveMainloop::StrideB; - handle_shuffle_tensor_( - operator_args, arguments, problem_n, problem_k, options_l); - } - if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) { - using ActualLayoutA = cutlass::layout::RowMajor; - using ActualStrideA = cutlass::detail::TagToStrideA_t; - using LayoutA_Reordered = typename CollectiveMainloop::StrideA; - handle_shuffle_tensor_( - operator_args, arguments, problem_m, problem_k, options_l); - } - } // End of "if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{}))" - - /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ - operator_args.hw_info.sm_count = arguments->sm_count; - if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { - operator_args.hw_info.max_active_clusters = max_active_clusters; - } - if constexpr (!std::is_const_v) { - operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; - } - - if constexpr (!std::is_const_v) { - using Enum_t = decltype(operator_args.scheduler.raster_order); - switch (arguments->raster_order) { - case RasterOrder::kAlongN: - operator_args.scheduler.raster_order = Enum_t::AlongN; - break; - case RasterOrder::kAlongM: - operator_args.scheduler.raster_order = Enum_t::AlongM; - break; - default: - operator_args.scheduler.raster_order = Enum_t::Heuristic; - } - } - - if constexpr (std::is_same_v) { - operator_args.scheduler.splits = arguments->split_k_slices; - } - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { - operator_args.hw_info.cluster_shape = dim3( - arguments->cluster_shape.m(), - arguments->cluster_shape.n(), - arguments->cluster_shape.k()); - operator_args.hw_info.cluster_shape_fallback = dim3( - arguments->cluster_shape_fallback.m(), - arguments->cluster_shape_fallback.n(), - arguments->cluster_shape_fallback.k()); - } - return status; - } - -public: - - /// Returns success if the operation can proceed - Status can_implement( - [[maybe_unused]] void const *configuration_ptr, void const *arguments_ptr) const override { - GemmUniversalArguments const *arguments = - static_cast(arguments_ptr); - OperatorArguments args; - - auto status = update_arguments_(args, arguments); - if (status != Status::kSuccess) { - return status; - } - - Status can_impl = Operator::can_implement(args); - - //return Operator::can_implement(args); - return can_impl; - } - - /// Gets the host-side workspace - uint64_t get_host_workspace_size(void const *configuration) const override { - return sizeof(Operator); - } - - /// Gets the device-side workspace - uint64_t get_device_workspace_size( - void const *configuration_ptr,void const *arguments_ptr) const override { - - OperatorArguments args; - auto status = update_arguments_( - args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - return size; - } - - /// Initializes the workspace - Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const override { - Operator *op = new (host_workspace) Operator; - return Status::kSuccess; - } - - /// Runs the kernel - Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { - - OperatorArguments args; - Status status = update_arguments_(args, static_cast(arguments_ptr), stream); - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(args, device_workspace, stream, nullptr, - static_cast(arguments_ptr)->use_pdl); - return status; - } -}; -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::library - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp deleted file mode 100644 index 91f618d4fab74a6d43e2d82c572d215d5bea5a1c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp +++ /dev/null @@ -1,873 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all grouped GEMM operations in CUTLASS Library. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/collective.hpp" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "gemm_operation_3x.hpp" -#include "library_internal.h" - -namespace cutlass::library { - -template -class GroupedGemmOperation3xBase : public GemmOperation3xBase { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), - "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); - static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - - GroupedGemmOperation3xBase(char const* name = "unknown_gemm") - : GemmOperation3xBase(name, GemmKind::kGrouped) { - this->description_.kind = OperationKind::kGroupedGemm; - this->description_.name = name; - this->description_.provider = Provider::kCUTLASS; - - this->description_.gemm = GemmOperation3xBase::description_; - this->description_.tile_description = this->description_.gemm.tile_description; - }; - -public: - mutable CudaBuffer strideA_device; - mutable CudaBuffer strideB_device; - mutable CudaBuffer strideC_device; - mutable CudaBuffer strideD_device; - - /// Returns the description of the GEMM operation - virtual OperationDescription const& description() const override final { return description_; } - /// Gets the host-side workspace - uint64_t get_host_workspace_size(void const* configuration) const override final { - return sizeof(Operator); - } - -protected: - library::GroupedGemmDescription description_; - - Status initialize_strides(GemmGroupedConfiguration const& config) const { - auto const num_groups = config.problem_count; - this->strideA_device = - CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups); - this->strideB_device = - CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups); - this->strideC_device = - CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups); - this->strideD_device = - CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups); - - std::vector strideA_host(num_groups); - std::vector strideB_host(num_groups); - std::vector strideC_host(num_groups); - std::vector strideD_host(num_groups); - for (int group_idx = 0; group_idx < num_groups; group_idx++) { - strideA_host[group_idx] = - cute::make_int_tuple_from( - config.lda[group_idx]); - strideB_host[group_idx] = - cute::make_int_tuple_from( - config.ldb[group_idx]); - strideC_host[group_idx] = - cute::make_int_tuple_from( - config.ldc[group_idx]); - strideD_host[group_idx] = - cute::make_int_tuple_from( - config.ldc[group_idx]); - } - CUDA_CHECK(cudaMemcpy( - this->strideA_device.data(), - strideA_host.data(), - sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy( - this->strideB_device.data(), - strideB_host.data(), - sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy( - this->strideC_device.data(), - strideC_host.data(), - sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy( - this->strideD_device.data(), - strideD_host.data(), - sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups, - cudaMemcpyHostToDevice)); - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - Status update_arguments_base( - OperatorArguments& operator_args, - GemmGroupedArguments const& arguments) const { - operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped; - operator_args.problem_shape = { - arguments.problem_count, - arguments.problem_sizes_3x, - arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host - : nullptr}; - - if constexpr (IsRuntimeDataType) { - using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; - operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); - operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); - - using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; - using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; - - static_assert(cute::is_same_v, - "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); - using RuntimeDatatypeArg = RuntimeDataTypeA; - - auto mapping = [](RuntimeDatatype type) { - if constexpr (cute::is_same_v) { - if (type == RuntimeDatatype::kE5M2) { - return cute::UMMA::MXF8F6F4Format::E5M2; - } - else if (type == RuntimeDatatype::kE4M3) { - return cute::UMMA::MXF8F6F4Format::E4M3; - } - else if (type == RuntimeDatatype::kE3M2) { - return cute::UMMA::MXF8F6F4Format::E3M2; - } - else if (type == RuntimeDatatype::kE2M3) { - return cute::UMMA::MXF8F6F4Format::E2M3; - } - else if (type == RuntimeDatatype::kE2M1) { - return cute::UMMA::MXF8F6F4Format::E2M1; - } - else { - #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 - std::cerr << "Invalid input datatype specified. Running with e4m3." << std::endl; - #endif - return cute::UMMA::MXF8F6F4Format::E4M3; - } - } - else if constexpr (cute::is_same_v) { - if (type == RuntimeDatatype::kE2M1) { - return cute::UMMA::MXF4Format::E2M1; - } - else { - #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 - std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl; - #endif - return cute::UMMA::MXF4Format::E2M1; - } - } - // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype - CUTE_GCC_UNREACHABLE; - }; - operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a); - operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b); - } - else { - operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); - operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); - } - operator_args.epilogue.ptr_C = static_cast(arguments.ptr_C); - operator_args.epilogue.ptr_D = static_cast(arguments.ptr_D); - - operator_args.mainloop.dA = - static_cast(this->strideA_device.data()); - operator_args.mainloop.dB = - static_cast(this->strideB_device.data()); - operator_args.epilogue.dC = - static_cast(this->strideC_device.data()); - operator_args.epilogue.dD = - static_cast(this->strideD_device.data()); - - /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ - operator_args.hw_info.sm_count = arguments.sm_count; - if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { - operator_args.hw_info.max_active_clusters = arguments.max_active_clusters; - } - if constexpr (!std::is_const_v) { - operator_args.scheduler.max_swizzle_size = arguments.swizzle_size; - } - - if constexpr (!std::is_const_v) { - using Enum_t = decltype(operator_args.scheduler.raster_order); - switch (arguments.raster_order) { - case RasterOrder::kAlongN: - operator_args.scheduler.raster_order = Enum_t::AlongN; - break; - case RasterOrder::kAlongM: - operator_args.scheduler.raster_order = Enum_t::AlongM; - break; - default: - operator_args.scheduler.raster_order = Enum_t::Heuristic; - } - } - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { - operator_args.hw_info.cluster_shape = - dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k()); - operator_args.hw_info.cluster_shape_fallback = dim3( - arguments.cluster_shape_fallback.m(), - arguments.cluster_shape_fallback.n(), - arguments.cluster_shape_fallback.k()); - } - return Status::kSuccess; - } - - template - static Status update_fusion_args(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { - if (arguments.pointer_mode == ScalarPointerMode::kHost) { - fusion_args.alpha = *static_cast(arguments.alpha); - fusion_args.beta = *static_cast(arguments.beta); - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = nullptr; - fusion_args.beta_ptr_array = nullptr; - - return Status::kSuccess; - } - else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = static_cast(arguments.alpha); - fusion_args.beta_ptr = static_cast(arguments.beta); - fusion_args.alpha_ptr_array = nullptr; - fusion_args.beta_ptr_array = nullptr; - return Status::kSuccess; - } - else { - return Status::kErrorInvalidProblem; - } - } -}; - -/// **** CAUTION **** -/// Unlike other operations, initialize() must be called when -/// certain arguments change. See initialize() for details. -template -class GroupedGemmUniversal3xOperation : public GroupedGemmOperation3xBase { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - -public: - GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm") - : GroupedGemmOperation3xBase(name) {} - - ~GroupedGemmUniversal3xOperation() override = default; - -private: - int max_active_clusters{}; - -protected: - template struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { - return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); - } - }; - - /// Constructs the arguments structure given the configuration and arguments - Status - update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const { - - Status status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, - *arguments); - if (status != Status::kSuccess) { - return status; - } - - status = this->update_arguments_base(operator_args, *arguments); - return status; - } - -public: - /// Returns success if the operation can proceed - Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) - const override { - GemmGroupedArguments const* arguments = static_cast(arguments_ptr); - OperatorArguments args; - auto status = update_arguments_(args, arguments); - if (status != Status::kSuccess) { - return status; - } - - status = Operator::can_implement(args); - return status; - } - - /// Gets the device-side workspace - uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) - const override { - - OperatorArguments args; - auto status = update_arguments_(args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - return size; - } - - /// Initializes the workspace - /// **** CAUTION **** - /// Must be called when lda, ldb, ldc, or ldd change. - /// The CUTLASS library stores the operations in a type- - /// erased manifest. Therefore, only this class knows - /// the type of strideA, strideB, strideC, and strideD. - /// Since grouped GEMM needs to allocate storage for - /// the strides on device, the concrete type of the stride - /// must be known in order to copy in the correct memory - /// layout on device. - Status initialize( - void const* configuration_ptr, - void* host_workspace, - void* device_workspace, - cudaStream_t stream = nullptr) const override { - - Operator* op = new (host_workspace) Operator; - - auto const& config = *static_cast(configuration_ptr); - return this->initialize_strides(config); - } - - /// **** CAUTION **** - /// initialize() must be called if lda, ldb, ldc, or ldd change. - Status run( - void const* arguments_ptr, - void* host_workspace, - void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { - - OperatorArguments operator_args; - auto const& args = *static_cast(arguments_ptr); - - Status status = update_arguments_(operator_args, &args); - if (status != Status::kSuccess) { - return status; - } - - Operator* op = static_cast(host_workspace); - // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl); - return status; - } - - // Set arguments that should only be set once before verifying or profiling the kernel. - // This should encompass any expensive operations that don't vary from run to run - // (e.g., max_active_clusters). - Status initialize_with_arguments(void* arguments_ptr) const override { - if constexpr (Operator::ArchTag::kMinComputeCapability < 90) { - return Status::kSuccess; - } - - GemmGroupedArguments* args = static_cast(arguments_ptr); - - dim3 cluster_dims; - if constexpr (cute::is_static_v) { - cluster_dims = dim3( - cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<2>(typename Operator::GemmKernel::ClusterShape{}) - ); - } - else { - cluster_dims = dim3( - args->cluster_shape.m(), - args->cluster_shape.n(), - args->cluster_shape.k() - ); - } - - uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; - void const* kernel_ptr = (void*)(device_kernel); - args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( - cluster_dims, - threads_per_block, - kernel_ptr); - - if (args->max_active_clusters == 0) { - std::cerr << "Max Active Clusters could not be queried. " - << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; - } - - return Status::kSuccess; - } -}; - -template -class GroupedBlockScaledGemmUniversal3xOperation : public GroupedGemmOperation3xBase { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; - using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; - - using TiledMma = typename Operator::CollectiveMainloop::TiledMma; - constexpr static int SFVecSize = TiledMma::SFVecSize; - - - static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; - static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; - using ElementSFD = cute::conditional_t; - using LayoutSFD = cute::conditional_t; - - GroupedBlockScaledGemmUniversal3xOperation(char const* name = "unknown_gemm") - : GroupedGemmOperation3xBase(name) { - - BlockScaleDescription block_scaled_desc{}; - block_scaled_desc.kind = OperationKind::kBlockScaledGemm; - block_scaled_desc.SFA.element = NumericTypeMap::kId; - block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor; - block_scaled_desc.SFA.alignment = 128; - block_scaled_desc.SFA.log_extent_range = 32; - block_scaled_desc.SFA.log_stride_range = 32; - - block_scaled_desc.SFB.element = NumericTypeMap::kId; - block_scaled_desc.SFB.layout = LayoutTypeID::kRowMajor; - block_scaled_desc.SFB.alignment = 128; - block_scaled_desc.SFB.log_extent_range = 32; - block_scaled_desc.SFB.log_stride_range = 32; - - block_scaled_desc.SFMVecSize = 1; - block_scaled_desc.SFNVecSize = 1; - block_scaled_desc.SFKVecSize = SFVecSize; - - block_scaled_desc.SFD = make_TensorDescription(128); - block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize; - - this->description_.block_scales = block_scaled_desc; - } - - ~GroupedBlockScaledGemmUniversal3xOperation() override = default; - - mutable CudaBuffer layout_SFA_device; - mutable CudaBuffer layout_SFB_device; - -protected: - template struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status - update_(FusionArgs& fusion_args, GroupedGemmBlockScaledArguments const& arguments) { - - if constexpr (epilogue_scalefactor_generation) { - fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); - fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); - } - - return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); - } - }; - -public: - /// Returns success if the operation can proceed - Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) - const override { - GroupedGemmBlockScaledArguments const* arguments = - static_cast(arguments_ptr); - OperatorArguments args; - auto status = update_arguments_(args, arguments); - if (status != Status::kSuccess) { - return status; - } - - status = Operator::can_implement(args); - return status; - } - - Status update_arguments_( - OperatorArguments& operator_args, - GroupedGemmBlockScaledArguments const* arguments) const { - Status status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, - *arguments); - if (status != Status::kSuccess) { - return status; - } - - operator_args.mainloop.ptr_SFA = - static_cast(arguments->SFA); - operator_args.mainloop.ptr_SFB = - static_cast(arguments->SFB); - - operator_args.mainloop.layout_SFA = - static_cast(this->layout_SFA_device.data()); - operator_args.mainloop.layout_SFB = - static_cast(this->layout_SFB_device.data()); - - return this->update_arguments_base(operator_args, *arguments); - } - - uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) - const override { - - OperatorArguments args; - auto status = - update_arguments_(args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - return size; - } - - /// Initializes the workspace - /// **** CAUTION **** - /// Must be called when lda, ldb, ldc, or ldd change. - /// The CUTLASS library stores the operations in a type- - /// erased manifest. Therefore, only this class knows - /// the type of strideA, strideB, strideC, and strideD. - /// Since grouped GEMM needs to allocate storage for - /// the strides on device, the concrete type of the stride - /// must be known in order to copy in the correct memory - /// layout on device. - Status initialize( - void const* configuration_ptr, - void* host_workspace, - void* device_workspace, - cudaStream_t stream = nullptr) const override { - - auto const& config = *static_cast(configuration_ptr); - auto status = this->initialize_strides(config); - if (status != Status::kSuccess) { - return status; - } - - auto num_groups = config.problem_count; - this->layout_SFA_device = - CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); - this->layout_SFB_device = - CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); - auto layout_SFA_host = std::vector(num_groups); - auto layout_SFB_host = std::vector(num_groups); - - for (int group_idx = 0; group_idx < num_groups; group_idx++) { - auto const& shape = config.problem_sizes_3x_host[group_idx]; - auto M = get<0>(shape); - auto N = get<1>(shape); - auto K = get<2>(shape); - - auto layout_SFA = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - auto layout_SFB = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - layout_SFA_host[group_idx] = layout_SFA; - layout_SFB_host[group_idx] = layout_SFB; - } - - CUDA_CHECK(cudaMemcpy( - this->layout_SFA_device.data(), - layout_SFA_host.data(), - sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy( - this->layout_SFB_device.data(), - layout_SFB_host.data(), - sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, - cudaMemcpyHostToDevice)); - - Operator* op = new (host_workspace) Operator; - return status; - } - - /// **** CAUTION **** - /// initialize() must be called if lda, ldb, ldc, or ldd change. - Status run( - void const* arguments_ptr, - void* host_workspace, - void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { - - OperatorArguments operator_args; - auto const& args = *static_cast(arguments_ptr); - - Status status = update_arguments_(operator_args, &args); - if (status != Status::kSuccess) { - return status; - } - - Operator* op = static_cast(host_workspace); - status = op->run(operator_args, device_workspace, stream, nullptr); - return status; - } -}; - -template -class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase { -public: - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - using ElementSFA = typename Operator::ElementAccumulator; - using ElementSFB = typename Operator::ElementAccumulator; - - using TiledMma = typename Operator::CollectiveMainloop::TiledMma; - - GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm") - : GroupedGemmOperation3xBase(name) { - - BlockScaleDescription blockwise_desc{}; - blockwise_desc.kind = OperationKind::kBlockwiseGemm; - blockwise_desc.SFA.element = NumericTypeMap::kId; - blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ? - LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; - blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA; - blockwise_desc.SFA.log_extent_range = 32; - blockwise_desc.SFA.log_stride_range = 32; - - blockwise_desc.SFB.element = NumericTypeMap::kId; - blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ? - LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; - blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA; - blockwise_desc.SFB.log_extent_range = 32; - blockwise_desc.SFB.log_stride_range = 32; - - blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; - blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; - blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; - - blockwise_desc.EpilogueSFVecSize = 0; - - this->description_.block_scales = blockwise_desc; - } - - ~GroupedBlockwiseGemmUniversal3xOperation() override = default; - - mutable CudaBuffer layout_SFA_device; - mutable CudaBuffer layout_SFB_device; - -protected: - template struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status - update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) { - return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); - } - }; - -public: - /// Returns success if the operation can proceed - Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) - const override { - GroupedGemmBlockwiseArguments const* arguments = - static_cast(arguments_ptr); - OperatorArguments args; - auto status = update_arguments_(args, arguments); - if (status != Status::kSuccess) { - return status; - } - - status = Operator::can_implement(args); - return status; - } - - Status update_arguments_( - OperatorArguments& operator_args, - GroupedGemmBlockwiseArguments const* arguments) const { - Status status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, - *arguments); - if (status != Status::kSuccess) { - return status; - } - - operator_args.mainloop.ptr_SFA = - static_cast(arguments->SFA); - operator_args.mainloop.ptr_SFB = - static_cast(arguments->SFB); - - operator_args.mainloop.layout_SFA = - static_cast(this->layout_SFA_device.data()); - operator_args.mainloop.layout_SFB = - static_cast(this->layout_SFB_device.data()); - - return this->update_arguments_base(operator_args, *arguments); - } - - uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) - const override { - - OperatorArguments args; - auto status = - update_arguments_(args, static_cast(arguments_ptr)); - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - return size; - } - - /// Initializes the workspace - /// **** CAUTION **** - /// Must be called when lda, ldb, ldc, or ldd change. - /// The CUTLASS library stores the operations in a type- - /// erased manifest. Therefore, only this class knows - /// the type of strideA, strideB, strideC, and strideD. - /// Since grouped GEMM needs to allocate storage for - /// the strides on device, the concrete type of the stride - /// must be known in order to copy in the correct memory - /// layout on device. - Status initialize( - void const* configuration_ptr, - void* host_workspace, - void* device_workspace, - cudaStream_t stream = nullptr) const override { - - auto const& config = *static_cast(configuration_ptr); - auto status = this->initialize_strides(config); - if (status != Status::kSuccess) { - return status; - } - - auto num_groups = config.problem_count; - this->layout_SFA_device = - CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); - this->layout_SFB_device = - CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); - auto layout_SFA_host = std::vector(num_groups); - auto layout_SFB_host = std::vector(num_groups); - - for (int group_idx = 0; group_idx < num_groups; group_idx++) { - auto const& shape = config.problem_sizes_3x_host[group_idx]; - auto M = get<0>(shape); - auto N = get<1>(shape); - auto K = get<2>(shape); - - auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - layout_SFA_host[group_idx] = layout_SFA; - layout_SFB_host[group_idx] = layout_SFB; - } - - CUDA_CHECK(cudaMemcpy( - this->layout_SFA_device.data(), - layout_SFA_host.data(), - sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy( - this->layout_SFB_device.data(), - layout_SFB_host.data(), - sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, - cudaMemcpyHostToDevice)); - - Operator* op = new (host_workspace) Operator; - return status; - } - - /// **** CAUTION **** - /// initialize() must be called if lda, ldb, ldc, or ldd change. - Status run( - void const* arguments_ptr, - void* host_workspace, - void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { - - OperatorArguments operator_args; - auto const& args = *static_cast(arguments_ptr); - - Status status = update_arguments_(operator_args, &args); - if (status != Status::kSuccess) { - return status; - } - - Operator* op = static_cast(host_workspace); - status = op->run(operator_args, device_workspace, stream, nullptr); - return status; - } -}; - - -} // namespace cutlass::library diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h deleted file mode 100644 index e8bd77397f3b85cce2da2a7a8e447ab6ccb48aea..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h +++ /dev/null @@ -1,427 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - - \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. - - Generally, - - description - compile-time constant parameters used to instantiate an operation - - configuration - runtime parameters with computationally expensive initialization - - arguments - runtime parameters that may be passed to an initialized operation with low - computational overhead -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/layout/matrix.h" - -#include "cutlass/library/library.h" -#include "cutlass/library/arch_mappings.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct NumericTypeMap; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kVoid; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kB1; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kS2; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kS4; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kS8; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kS16; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kS32; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kS64; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kU2; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kU4; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kU8; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFE4M3; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFE5M2; -}; - - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFE2M3; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFE3M2; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFE2M1; -}; -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFUE8M0; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kFUE4M3; -}; - - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kU16; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kU32; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kU64; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kF16; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kF32; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kF64; -}; - -template <> struct NumericTypeMap > { - static NumericTypeID const kId = NumericTypeID::kCF16; -}; - -template <> struct NumericTypeMap > { - static NumericTypeID const kId = NumericTypeID::kCF32; -}; - -template <> struct NumericTypeMap > { - static NumericTypeID const kId = NumericTypeID::kCF64; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kBF16; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kTF32; -}; - - - - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kF8; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kF6; -}; - -template <> struct NumericTypeMap { - static NumericTypeID const kId = NumericTypeID::kF4; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kInvalid; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAdd; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddMixedInputUpcast; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kXorPopc; -}; - - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; -}; - -template <> struct MathOperationMap { - static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct LayoutMap; - -template <> struct LayoutMap { - static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; -}; - -template <> struct LayoutMap { - static LayoutTypeID const kId = LayoutTypeID::kRowMajor; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; -}; - -template <> struct LayoutMap { - static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; -}; - -template <> struct LayoutMap { - static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; -}; - -template <> struct LayoutMap> { - static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct OpcodeClassMap; - -template <> struct OpcodeClassMap { - static OpcodeClassID const kId = OpcodeClassID::kSimt; -}; - -template <> struct OpcodeClassMap { - static OpcodeClassID const kId = OpcodeClassID::kTensorOp; -}; - -template <> struct OpcodeClassMap { - static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp; -}; - - -template <> struct OpcodeClassMap { - static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp; -}; - - -template <> struct OpcodeClassMap { - static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct ComplexTransformMap; - -template <> struct ComplexTransformMap { - static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; -}; - -template <> struct ComplexTransformMap { - static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template struct ConvModeMap; - -template <> struct ConvModeMap { - static ConvModeID const kId = ConvModeID::kCrossCorrelation; -}; - -template <> struct ConvModeMap { - static ConvModeID const kId = ConvModeID::kConvolution; -}; - - -template struct ConvKindMap; - -template <> struct ConvKindMap { - static ConvKind const kId = ConvKind::kFprop; -}; - -template <> struct ConvKindMap { - static ConvKind const kId = ConvKind::kDgrad; -}; - -template <> struct ConvKindMap { - static ConvKind const kId = ConvKind::kWgrad; -}; - - -template struct IteratorAlgorithmMap; - -template <> struct IteratorAlgorithmMap { - static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; -}; - -template <> struct IteratorAlgorithmMap { - static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; -}; - -template <> struct IteratorAlgorithmMap { - static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; -}; - -template <> struct IteratorAlgorithmMap { - static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -TensorDescription make_TensorDescription(int alignment = 1) { - TensorDescription desc; - - desc.element = NumericTypeMap::kId; - desc.layout = LayoutMap::kId; - desc.alignment = alignment; - desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; - desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; - - return desc; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h deleted file mode 100644 index 76d8d0dfdb1aa6ed0324b9d6299b06ebf3f436d9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h +++ /dev/null @@ -1,377 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) - in CUTLASS Library. - - -*/ - -#pragma once -#include -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/device/rank_2k.h" -#include "cutlass/gemm/kernel/default_rank_2k_universal.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/core_io.h" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class Rank2KOperationBase : public Operation { -public: - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static BlasMode const kBlasMode = Operator::kBlasMode; - static int const kUpdateRank = Operator::kUpdateRank; - static FillMode const kFillModeC = Operator::kFillModeC; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - RankKDescription description_; - -public: - - /// Constructor - Rank2KOperationBase(char const *name = "unknown_rank_k") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.rank_k_kind = RankKKind::kUniversal; - description_.fill_mode = kFillModeC; - description_.blas_mode = kBlasMode; - description_.num_ranks = kUpdateRank; - - description_.kind = OperationKind::kRank2K; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::Rank2Kkernel::WarpCount::kM, - Operator::Rank2Kkernel::WarpCount::kN, - Operator::Rank2Kkernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - description_.transform_A = ComplexTransformMap::kId; - description_.transform_B = ComplexTransformMap::kId; - } - - /// Returns the description of the SYRK operation - virtual OperationDescription const & description() const { - return description_; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class Rank2KOperation : public Rank2KOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - static BlasMode const kBlasMode = Operator::kBlasMode; - static int const kUpdateRank = Operator::kUpdateRank; - static FillMode const kFillModeC = Operator::kFillModeC; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - Rank2KOperation(char const *name = "unknown_rank_2k"): - Rank2KOperationBase(name) { - - this->description_.rank_k_kind = RankKKind::kUniversal; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - RankKConfiguration const *configuration) { - - //operator_args.mode = configuration->mode; - - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - operator_args.lda = int(configuration->lda); - operator_args.ldb = int(configuration->ldb); - operator_args.ldc = int(configuration->ldc); - operator_args.ldd = int(configuration->ldd); - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - RankKArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A = arguments->A; - operator_args.ptr_B = arguments->B; - operator_args.ptr_C = arguments->C; - operator_args.ptr_D = arguments->D; - - operator_args.batch_stride_A = arguments->batch_stride_A; - operator_args.batch_stride_B = arguments->batch_stride_B; - operator_args.batch_stride_C = arguments->batch_stride_C; - operator_args.batch_stride_D = arguments->batch_stride_D; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - RankKConfiguration const *configuration = - static_cast(configuration_ptr); - - RankKArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - //std::cout << "initialize() library::Rank2KOperation" << std::endl; - //print_operator_args(args); - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - - //std::cout << "run() library::Rank2KOperation" << std::endl; - //print_operator_args(args); - status = op->run(stream); - - return status; - } - - /// Call print_operator_args from the Conv2dOperation::initialize() - // to dump arguments passed on to cutlass operator for debugging - void print_operator_args(OperatorArguments &operator_args) const { - std::cout << "Rank2KOperation::OperatorArguments" << std::endl - << " problem_size:" << std::endl - << operator_args.problem_size << std::endl - << " epilogue (alpha, beta): " - << operator_args.epilogue.alpha << ", " - << operator_args.epilogue.beta << std::endl - << " ref_A (ptr, {stride}): " - << operator_args.ptr_A << ", {" - << operator_args.lda << "}" << std::endl - << " ref_B (ptr, {stride}): " - << operator_args.ptr_B << ", {" - << operator_args.ldb << "}" << std::endl - << " ref_C (ptr, {stride}): " - << operator_args.ptr_C << ", {" - << operator_args.ldc << "}" << std::endl - << " ref_D (ptr, {stride}): " - << operator_args.ptr_D << ", {" - << operator_args.ldd << "}" << std::endl; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h deleted file mode 100644 index 021f7f03fcc4449bdc2ef2c97e29fe0fead09a64..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h +++ /dev/null @@ -1,348 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all Rank K operation kinds (Syrk, Herk) - in CUTLASS Library. - - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/device/rank_k.h" -#include "cutlass/gemm/kernel/default_rank_k_universal.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class RankKOperationBase : public Operation { -public: - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementA; - using LayoutB = typename Operator::LayoutA; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static BlasMode const kBlasMode = Operator::kBlasMode; - static int const kUpdateRank = Operator::kUpdateRank; - static FillMode const kFillModeC = Operator::kFillModeC; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - RankKDescription description_; - -public: - - /// Constructor - RankKOperationBase(char const *name = "unknown_rank_k") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.rank_k_kind = RankKKind::kUniversal; - description_.fill_mode = kFillModeC; - description_.blas_mode = kBlasMode; - description_.num_ranks = kUpdateRank; - - description_.kind = OperationKind::kRankK; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::RankKkernel::WarpCount::kM, - Operator::RankKkernel::WarpCount::kN, - Operator::RankKkernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentA); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - description_.transform_A = ComplexTransformMap::kId; - description_.transform_B = ComplexTransformMap::kId; - } - - /// Returns the description of the SYRK operation - virtual OperationDescription const & description() const { - return description_; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class RankKOperation : public RankKOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementA; - using LayoutB = typename Operator::LayoutA; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - static BlasMode const kBlasMode = Operator::kBlasMode; - static int const kUpdateRank = Operator::kUpdateRank; - static FillMode const kFillModeC = Operator::kFillModeC; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - RankKOperation(char const *name = "unknown_rank_k"): - RankKOperationBase(name) { - - this->description_.rank_k_kind = RankKKind::kUniversal; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - RankKConfiguration const *configuration) { - - //operator_args.mode = configuration->mode; - - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - operator_args.lda = int(configuration->lda); - operator_args.ldb = int(configuration->lda); - operator_args.ldc = int(configuration->ldc); - operator_args.ldd = int(configuration->ldd); - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - RankKArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A = arguments->A; - operator_args.ptr_C = arguments->C; - operator_args.ptr_D = arguments->D; - - operator_args.batch_stride_A = arguments->batch_stride_A; - operator_args.batch_stride_C = arguments->batch_stride_C; - operator_args.batch_stride_D = arguments->batch_stride_D; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - RankKConfiguration const *configuration = - static_cast(configuration_ptr); - - RankKArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - - status = op->run(stream); - - return status; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h deleted file mode 100644 index 6e948540e3f29dceace42b5e8ef3f91118c01b37..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h +++ /dev/null @@ -1,294 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for reduction operation in CUTLASS Library. -*/ - -#pragma once -#include -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_clamp.h" -#include "cutlass/reduction/thread/reduction_operators.h" -#include "cutlass/reduction/device/reduce_split_k.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/core_io.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class ReductionOperation : public Operation { -public: - using Operator = Operator_; - - using ElementWorkspace = typename Operator::ElementWorkspace; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementOutput = typename Operator::ElementOutput; - - using ElementCompute = typename Operator::OutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - ReductionDescription description_; - -public: - - /// Constructor - ReductionOperation(char const *name = "unknown_reduction") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.kind = OperationKind::kReduction; - - description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); - - description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); - description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; - description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; - description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; - - description_.tile_description.minimum_compute_capability = 50; - description_.tile_description.maximum_compute_capability = 1024; - - description_.element_workspace = NumericTypeMap::kId; - description_.element_output = NumericTypeMap::kId; - description_.element_epilogue = NumericTypeMap::kId; - - } - - /// Returns the description of the Reduction operation - virtual OperationDescription const & description() const { - return description_; - } - - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - ReductionConfiguration const *configuration) { - - operator_args.problem_size = configuration->problem_size; - operator_args.partitions = configuration->partitions; - operator_args.partition_stride = configuration->partition_stride; - - operator_args.workspace = {nullptr, int(configuration->ldw)}; - operator_args.source = {nullptr, int(configuration->lds)}; - operator_args.destination = {nullptr, int(configuration->ldd)}; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - ReductionArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::OutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.output = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::OutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.output = params; - } - else { - return Status::kErrorInvalidProblem; - } - - operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); - operator_args.source.reset(static_cast(const_cast(arguments->source))); - operator_args.destination.reset(static_cast(const_cast(arguments->destination))); - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - ReductionConfiguration const *configuration = - static_cast(configuration_ptr); - - ReductionArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - //std::cout << "initialize library::Reduction" << std::endl; - //print_operator_args(args); - return op->initialize(args, device_workspace, stream); - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - - //std::cout << "run library::Reduction" << std::endl; - //print_operator_args(args); - return op->run(stream); - } - - /// Call print_operator_args from the Reduction::initialize() - // to dump arguments passed on to cutlass operator for debugging - void print_operator_args(OperatorArguments &operator_args) const { - std::cout << "Reduction::OperatorArguments" << std::endl - << " problem_size: " - << operator_args.problem_size << std::endl - << " partitions: " - << operator_args.partitions << std::endl - << " partition_stride: " - << operator_args.partition_stride << std::endl - << " epilogue (alpha, beta): " - << operator_args.output.alpha << ", " - << operator_args.output.beta << std::endl - << " workspace (ptr, stride): " - << operator_args.workspace.data() << ", " - << operator_args.workspace.stride(0) << std::endl - << " source (ptr, stride): " - << operator_args.source.data() << ", " - << operator_args.source.stride(0) << std::endl - << " destination (ptr, stride): " - << operator_args.destination.data() << ", " - << operator_args.destination.stride(0) << std::endl; - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h deleted file mode 100644 index 769da1c8515877536fd9b9fd72c836fd43ebd5d8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h +++ /dev/null @@ -1,453 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library -*/ - - - -#pragma once - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/util.h" -#include "cutlass/util/packed_stride.hpp" -#include "library_internal.h" - -#include "cutlass/util/reference/host/gett.hpp" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -namespace detail { -template -auto make_iterator(T* ptr) { - return cute::recast_ptr(ptr); -} -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - Provider Provider_, - typename ElementA_, - typename LayoutA_, - typename ElementSFA_, - typename ElementB_, - typename LayoutB_, - typename ElementSFB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ElementSFD_ = void, - typename LayoutSFD_ = LayoutC_, - int SFVecSize_ = 32, - int EpilogueSFVecSize_ = 0, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -class BlockScaledGemmReferenceOperation : public Operation { -public: - static Provider const kProvider = Provider_; - - using ElementA = ElementA_; - using LayoutA = LayoutA_; - using ElementSFA = ElementSFA_; - using ElementB = ElementB_; - using LayoutB = LayoutB_; - using ElementSFB = ElementSFB_; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using ElementD = ElementD_; - using ElementSFD = ElementSFD_; - using LayoutSFD = LayoutSFD_; - using ElementCompute = ElementCompute_; - using ElementAccumulator = ElementAccumulator_; - using ConvertOp = ConvertOp_; - using InnerProductOp = InnerProductOp_; - constexpr static int SFVecSize = SFVecSize_; - constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_; - -protected: - - /// Storage for the name string - std::string name_; - - /// - BlockScaledGemmDescription description_; - -public: - - /// Constructor - BlockScaledGemmReferenceOperation() { - - // Basic information - description_.provider = kProvider; - description_.kind = OperationKind::kBlockScaledGemm; - description_.gemm_kind = GemmKind::kUniversal; - - // Tensor description - description_.A = make_TensorDescription(); - description_.SFA = make_TensorDescription(); - description_.B = make_TensorDescription(); - description_.SFB = make_TensorDescription(); - description_.C = make_TensorDescription(); - description_.D = make_TensorDescription(); - description_.SFD = make_TensorDescription(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - // Compute capability for gemm reference - description_.tile_description.minimum_compute_capability = - (kProvider == Provider::kReferenceDevice ? 50 : 0); - - description_.tile_description.maximum_compute_capability = 1024; - - description_.SFVecSize = SFVecSize; - description_.EpilogueSFVecSize = EpilogueSFVecSize; - - // Procedural name - std::stringstream ss; - - ss << "gemm" - << "_reference_" << to_string(description_.provider) - << "_" << to_string(description_.A.element) << to_string(description_.A.layout) - << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) - << "_" << to_string(description_.B.element) << to_string(description_.B.layout) - << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout) - << "_" << to_string(description_.C.element) << to_string(description_.C.layout) - << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout) - << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); - - name_ = ss.str(); - - description_.name = name_.c_str(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - virtual Status can_implement( - void const *configuration, - void const *arguments) const { - - return Status::kSuccess; - } - - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(GemmUniversalConfiguration); - } - - virtual uint64_t get_device_workspace_size( - void const *configuration, - void const *arguments = nullptr) const { - - return 0; - } - - virtual Status initialize( - void const *configuration, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - return Status::kSuccess; - } - - virtual Status run( - void const *arguments, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - using namespace cute; - - BlockScaledGemmArguments const &args = *static_cast(arguments); - - // Construct cute::Tensor A/B/C - - int M = args.problem_size.m(); - int N = args.problem_size.n(); - int K = args.problem_size.k(); - int L = args.batch_count; - - auto problem_shape_MNKL = cute::make_shape(M, N, K, L); - - auto alpha = *(static_cast(args.alpha)); - auto beta = *(static_cast(args.beta)); - - using StrideA = cutlass::gemm::TagToStrideA_t; - using StrideB = cutlass::gemm::TagToStrideB_t; - using StrideC = cutlass::gemm::TagToStrideC_t; - using StrideD = cutlass::gemm::TagToStrideC_t; - - auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - - using Sm1xxBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; - auto A = cute::make_tensor(detail::make_iterator(static_cast(args.A)), - cute::make_layout(cute::make_shape(M, K, L), stride_a)); - auto SfA = make_tensor(static_cast(args.SFA), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); - - auto B = cute::make_tensor(detail::make_iterator(static_cast(args.B)), - cute::make_layout(cute::make_shape(N, K, L), stride_b)); - auto SfB = make_tensor(static_cast(args.SFB), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); - - auto C = [&]() { - if constexpr (not is_same_v) { - return cute::make_tensor(detail::make_iterator(static_cast(args.C)), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - } - else { - return cute::make_tensor(detail::make_iterator(static_cast(nullptr)), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - } - }(); - - auto D = cute::make_tensor(detail::make_iterator(static_cast(args.D)), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - - cutlass::reference::host::GettBlockScalingMainloopParams - mainloop_params{A, SfA, B, SfB}; - - if constexpr (not is_same_v) { - - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< - EpilogueSFVecSize - >; - - auto SfD = cute::make_tensor(detail::make_iterator(static_cast(args.SFD)), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); - - cutlass::reference::host::GettBlockScalingEpilogueParams< - ElementCompute, ElementAccumulator, ElementCompute, - decltype(C), decltype(D), decltype(SfD), Int, cutlass::reference::host::SfStrategy::SfDGen> - epilogue_params{alpha, beta, C, D, SfD, *(static_cast(args.norm_constant))}; - - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - } - else { - // W/O SF generation - auto SfD = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L))); // not used. - cutlass::reference::host::GettBlockScalingEpilogueParams< - ElementCompute, ElementAccumulator, ElementCompute, - decltype(C), decltype(D), decltype(SfD)> - epilogue_params{alpha, beta, C, D, SfD}; - - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - } - - return Status::kSuccess; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA_, - typename ElementSFA_, - typename ElementB_, - typename ElementSFB_, - typename ElementC_, - typename ElementCompute_, - typename ElementSFD_ = void, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - int SFVecSize = 32, - int EpilogueSFVecSize = SFVecSize, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_block_scaled_gemm_tn(Manifest &manifest) { -#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) - manifest.append(new BlockScaledGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ElementSFD_, - cutlass::layout::RowMajor, - SFVecSize, - EpilogueSFVecSize, - ConvertOp_, - InnerProductOp_ - >); -#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA_, - typename ElementSFA_, - typename ElementB_, - typename ElementSFB_, - typename ElementC_, - typename ElementCompute_, - typename ElementSFD_ = void, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - int SFVecSize = 32, - int EpilogueSFVecSize = SFVecSize, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_block_scaled_gemm(Manifest &manifest) { - /// - /// A is Row , B is Col - /// - manifest.append(new BlockScaledGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ElementSFD_, - cutlass::layout::RowMajor, - SFVecSize, - EpilogueSFVecSize, - ConvertOp_, - InnerProductOp_ - >); - manifest.append(new BlockScaledGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ElementSFD_, - cutlass::layout::RowMajor, - SFVecSize, - EpilogueSFVecSize, - ConvertOp_, - InnerProductOp_ - >); - /// - /// A is Col , B is Row - /// - manifest.append(new BlockScaledGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ElementSFD_, - cutlass::layout::RowMajor, - SFVecSize, - EpilogueSFVecSize, - ConvertOp_, - InnerProductOp_ - >); - manifest.append(new BlockScaledGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ElementSFD_, - cutlass::layout::RowMajor, - SFVecSize, - EpilogueSFVecSize, - ConvertOp_, - InnerProductOp_ - >); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h deleted file mode 100644 index fd988f899f563acfc6f8003bdb49523bca51d6d9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h +++ /dev/null @@ -1,807 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library -*/ - - - -#pragma once - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/util.h" -#include "cutlass/util/packed_stride.hpp" -#include "library_internal.h" - -#include "cutlass/util/reference/host/gett.hpp" -#include "cutlass/detail/blockwise_scale_layout.hpp" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - Provider Provider_, - typename ElementA_, - typename LayoutA_, - typename LayoutSFA_, - typename ElementSFA_, - typename ElementB_, - typename LayoutB_, - typename LayoutSFB_, - typename ElementSFB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -class BlockwiseGemmReferenceOperation : public Operation { -public: - static Provider const kProvider = Provider_; - - using ElementA = ElementA_; - using LayoutA = LayoutA_; - using ElementSFA = ElementSFA_; - using ElementB = ElementB_; - using LayoutB = LayoutB_; - using ElementSFB = ElementSFB_; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using ElementD = ElementD_; - using ElementCompute = ElementCompute_; - using ElementAccumulator = ElementAccumulator_; - using ConvertOp = ConvertOp_; - using InnerProductOp = InnerProductOp_; - -protected: - - /// Storage for the name string - std::string name_; - - /// - BlockwiseGemmDescription description_; - -public: - - /// Constructor - BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_) - : SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) { - - // Basic information - description_.provider = kProvider; - description_.kind = OperationKind::kBlockwiseGemm; - description_.gemm_kind = GemmKind::kUniversal; - - // Tensor description - description_.A = make_TensorDescription(); - description_.SFA = make_TensorDescription(); - description_.B = make_TensorDescription(); - description_.SFB = make_TensorDescription(); - description_.C = make_TensorDescription(); - description_.D = make_TensorDescription(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - // Compute capability for gemm reference - description_.tile_description.minimum_compute_capability = - (kProvider == Provider::kReferenceDevice ? 50 : 0); - - description_.tile_description.maximum_compute_capability = 1024; - - description_.SFMVecSize = SFMVecSize; - description_.SFNVecSize = SFNVecSize; - description_.SFKVecSize = SFKVecSize; - - // Procedural name - std::stringstream ss; - - ss << "gemm" - << "_reference_" << to_string(description_.provider) - << "_" << to_string(description_.A.element) << to_string(description_.A.layout) - << "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout) - << "_" << to_string(description_.B.element) << to_string(description_.B.layout) - << "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout) - << "_" << to_string(description_.C.element) << to_string(description_.C.layout) - << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); - - name_ = ss.str(); - - description_.name = name_.c_str(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - virtual Status can_implement( - void const *configuration, - void const *arguments) const { - - return Status::kSuccess; - } - - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(GemmUniversalConfiguration); - } - - virtual uint64_t get_device_workspace_size( - void const *configuration, - void const *arguments = nullptr) const { - - return 0; - } - - virtual Status initialize( - void const *configuration, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - return Status::kSuccess; - } - - virtual Status run( - void const *arguments, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - using namespace cute; - - BlockwiseGemmArguments const &args = *static_cast(arguments); - - // Construct cute::Tensor A/B/C - - int M = args.problem_size.m(); - int N = args.problem_size.n(); - int K = args.problem_size.k(); - int L = args.batch_count; - - auto problem_shape_MNKL = cute::make_shape(M, N, K, L); - - auto alpha = *(static_cast(args.alpha)); - auto beta = *(static_cast(args.beta)); - - using StrideA = cutlass::gemm::TagToStrideA_t; - using StrideB = cutlass::gemm::TagToStrideB_t; - using StrideC = cutlass::gemm::TagToStrideC_t; - using StrideD = cutlass::gemm::TagToStrideC_t; - - auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>; - auto A = cute::make_tensor(static_cast(args.A), - cute::make_layout(cute::make_shape(M, K, L), stride_a)); - auto SfA = make_tensor(static_cast(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); - - auto B = cute::make_tensor(static_cast(args.B), - cute::make_layout(cute::make_shape(N, K, L), stride_b)); - auto SfB = make_tensor(static_cast(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); - - auto C = [&]() { - if constexpr (not is_same_v) { - return cute::make_tensor(static_cast(args.C), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - } - else { - return cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L), stride_c)); - } - }(); - - auto D = cute::make_tensor(static_cast(args.D), - cute::make_layout(cute::make_shape(M, N, L), stride_d)); - - cutlass::reference::host::GettBlockScalingMainloopParams - mainloop_params{A, SfA, B, SfB}; - - // W/O SF generation - cutlass::reference::host::GettEpilogueParams< - ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute, - decltype(C), decltype(D)> - epilogue_params{alpha, beta, C, D}; - - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - - return Status::kSuccess; - } - -private: - int SFMVecSize; - int SFNVecSize; - int SFKVecSize; -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA_, - typename ElementSFA_, - typename ElementB_, - typename ElementSFB_, - typename ElementC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) { - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::RowMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - manifest.append(new BlockwiseGemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - ElementSFA_, - ElementB_, - cutlass::layout::RowMajor, - cutlass::layout::ColumnMajor, - ElementSFB_, - ElementC_, - cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(SFMVecSize, SFNVecSize, SFKVecSize)); - - -} - -template -void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 1 , 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 128, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 1, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 128, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 1, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 128, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 32, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 32, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 64, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 64, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 256, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 256, 128); - - - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 1 , 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 128, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 1, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 128, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 1 , 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 128, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 32, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 32, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 64, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 64, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 256, 128); - make_blockwise_gemm< - float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 256, 128); - - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 1 , 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 128, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 1, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 128, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 1, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 128, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 32, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 32, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 64, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 64, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 256, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 256, 128); - - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 1 , 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 128, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 1, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 128, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 1 , 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 128, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 32, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 32, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 64, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 64, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 128, 256, 128); - make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 1, 256, 128); - -} - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h deleted file mode 100644 index 240fe18d16a27778bf75e0c02f99d251c096353f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h +++ /dev/null @@ -1,636 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all CONV operation kinds in CUTLASS Library -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/util.h" -#include "library_internal.h" - -#include "cutlass/conv/convolution.h" -#include "cutlass/util/reference/host/convolution.h" -#include "cutlass/util/reference/device/convolution.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - Provider kProvider, - cutlass::conv::Operator ConvolutionalOperator, - int ConvDim, - typename ElementA_, - typename LayoutA_, - typename ElementB_, - typename LayoutB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -struct ConvReferenceDispatcher; - -/// Dispatcher for Conv2d (partially specialized for kConvDim == 2) -template < - Provider kProvider, - cutlass::conv::Operator kConvolutionalOperator, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator, - typename ConvertOp, - typename InnerProductOp -> -struct ConvReferenceDispatcher< - kProvider, - kConvolutionalOperator, - 2, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp> { - - static Status dispatch( - void const *configuration, - ElementA *ptr_A, - ElementB *ptr_B, - ElementC *ptr_C, - ElementC *ptr_D, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr - ) { - - Conv2dConfiguration const &config = - *static_cast(configuration); - - // TODO: make below code more general. It is fixed for NHWC now. - layout::TensorNHWC layout_a; - layout::TensorNHWC layout_b; - layout::TensorNHWC layout_c; - - layout_a.stride() = - make_Coord(int32_t(config.stride_a[0]), - int32_t(config.stride_a[1]), - int32_t(config.stride_a[2])); - - layout_b.stride() = - make_Coord(int32_t(config.stride_b[0]), - int32_t(config.stride_b[1]), - int32_t(config.stride_b[2])); - - layout_c.stride() = - make_Coord(int32_t(config.stride_c[0]), - int32_t(config.stride_c[1]), - int32_t(config.stride_c[2])); - - if (kProvider == Provider::kReferenceHost) { - - cutlass::reference::host::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC , - LayoutC, - ElementCompute, - ElementAccumulator, - ElementC, - ConvertOp, - InnerProductOp - >( - kConvolutionalOperator, - config.problem_size, - {ptr_A, layout_a}, - {ptr_B, layout_b}, - {ptr_C, layout_c}, - {ptr_D, layout_c}, - alpha, - beta - ); - - return Status::kSuccess; - } - else if (kProvider == Provider::kReferenceDevice) { - return cutlass::reference::device::Conv2d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp - >( - kConvolutionalOperator, - config.problem_size, - {ptr_A, layout_a}, - {ptr_B, layout_b}, - {ptr_C, layout_c}, - {ptr_D, layout_c}, - alpha, - beta, - stream - ); - } - return Status::kErrorNotSupported; - } -}; - -/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) -template < - Provider kProvider, - cutlass::conv::Operator kConvolutionalOperator, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator, - typename ConvertOp, - typename InnerProductOp -> -struct ConvReferenceDispatcher< - kProvider, - kConvolutionalOperator, - 3, - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp> { - - static Status dispatch( - void const *configuration, - ElementA *ptr_A, - ElementB *ptr_B, - ElementC *ptr_C, - ElementC *ptr_D, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr - ) { - - Conv3dConfiguration const &config = - *static_cast(configuration); - - ConvKind const conv_kind = ConvKindMap::kId; - - if (kProvider == Provider::kReferenceHost) { - cutlass::reference::host::Conv3d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC , - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp - >( - kConvolutionalOperator, - config.problem_size, - {ptr_A, config.layout_a(conv_kind)}, - {ptr_B, config.layout_b(conv_kind)}, - {ptr_C, config.layout_c(conv_kind)}, - {ptr_D, config.layout_c(conv_kind)}, - alpha, - beta - ); - - return Status::kSuccess; - } - else if (kProvider == Provider::kReferenceDevice) { - return cutlass::reference::device::Conv3d< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp - >( - kConvolutionalOperator, - config.problem_size, - {ptr_A, config.layout_a(conv_kind)}, - {ptr_B, config.layout_b(conv_kind)}, - {ptr_C, config.layout_c(conv_kind)}, - {ptr_D, config.layout_c(conv_kind)}, - alpha, - beta, - stream - ); - } - return Status::kErrorNotSupported; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - Provider Provider_, - cutlass::conv::Operator ConvolutionalOperator, - int ConvDim, - typename ElementA_, - typename LayoutA_, - typename ElementB_, - typename LayoutB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -class ConvReferenceOperation : public Operation { -public: - static Provider const kProvider = Provider_; - static cutlass::conv::Operator const kConvolutionalOperator = ConvolutionalOperator; - static int const kConvDim = ConvDim; - - using ElementA = ElementA_; - using LayoutA = LayoutA_; - using ElementB = ElementB_; - using LayoutB = LayoutB_; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using ElementCompute = ElementCompute_; - using ElementAccumulator = ElementAccumulator_; - using ConvertOp = ConvertOp_; - using InnerProductOp = InnerProductOp_; - -protected: - - /// Storage for the name string - std::string name_; - - /// - ConvDescription description_; - -public: - - /// Constructor - ConvReferenceOperation() { - - // Basic information - description_.provider = kProvider; - description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d); - description_.conv_kind = ConvKindMap::kId; - description_.conv_dim = kConvDim; - - // Tensor description - description_.A = make_TensorDescription(); - description_.B = make_TensorDescription(); - description_.C = make_TensorDescription(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - // Iterator algorithm for convolution reference - description_.iterator_algorithm = IteratorAlgorithmID::kNone; - - // Compute capability for convolution reference - description_.tile_description.minimum_compute_capability = - (kProvider == Provider::kReferenceDevice ? 50 : 0); - - description_.tile_description.maximum_compute_capability = 1024; - - // Procedural name - std::stringstream ss; - - ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) - << "_reference_" << to_string(description_.provider) - << "_" << to_string(description_.A.element) << to_string(description_.A.layout) - << "_" << to_string(description_.B.element) << to_string(description_.B.layout) - << "_" << to_string(description_.C.element) << to_string(description_.C.layout) - << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); - - name_ = ss.str(); - - description_.name = name_.c_str(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - virtual Status can_implement( - void const *configuration, - void const *arguments) const { - - return Status::kSuccess; - } - - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - switch (kConvDim) { - case 2: - return sizeof(Conv2dConfiguration); - case 3: - return sizeof(Conv3dConfiguration); - default: - break; - } - - return 0; - } - - virtual uint64_t get_device_workspace_size( - void const *configuration, - void const *arguments = nullptr) const { - - return 0; - } - - virtual Status initialize( - void const *configuration, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); - - return Status::kSuccess; - } - - virtual Status run( - void const *arguments, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - ConvArguments const &args = *static_cast(arguments); - - ElementCompute alpha; - ElementCompute beta; - - alpha = *static_cast(args.alpha); - beta = *static_cast(args.beta); - - // TODO - respect pointer mode - - // Invoke 2D or 3D convolution - return detail::ConvReferenceDispatcher< - kProvider, - kConvolutionalOperator, - kConvDim, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp - >::dispatch( - host_workspace, - static_cast(const_cast(args.A)), - static_cast(const_cast(args.B)), - static_cast(const_cast(args.C)), - static_cast(args.D), - alpha, - beta, - stream - ); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Constructs Fprop reference operators. -template < - int kConvDim, - typename ElementA_, - typename LayoutA_, - typename ElementB_, - typename LayoutB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_conv_fprop(Manifest &manifest) { -#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) - manifest.append(new ConvReferenceOperation< - Provider::kReferenceHost, - cutlass::conv::Operator::kFprop, - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >); - - manifest.append(new ConvReferenceOperation< - Provider::kReferenceDevice, - cutlass::conv::Operator::kFprop, - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >); -#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) -} - -/// Constructs Dgrad and Wgrad reference operators. -template < - int kConvDim, - typename ElementA_, - typename LayoutA_, - typename ElementB_, - typename LayoutB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_conv_backwards(Manifest &manifest) { -#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) - manifest.append(new ConvReferenceOperation< - Provider::kReferenceHost, - cutlass::conv::Operator::kDgrad, - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >); - - manifest.append(new ConvReferenceOperation< - Provider::kReferenceDevice, - cutlass::conv::Operator::kDgrad, - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >); - - manifest.append(new ConvReferenceOperation< - Provider::kReferenceHost, - cutlass::conv::Operator::kWgrad, - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >); - - manifest.append(new ConvReferenceOperation< - Provider::kReferenceDevice, - cutlass::conv::Operator::kWgrad, - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >); -#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) -} - -/// Six operators for the price of one. -template < - int kConvDim, - typename ElementA_, - typename LayoutA_, - typename ElementB_, - typename LayoutB_, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_conv_all(Manifest &manifest) { - - make_conv_fprop< - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_conv_backwards< - kConvDim, - ElementA_, LayoutA_, - ElementB_, LayoutB_, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ConvertOp_, - InnerProductOp_ - >(manifest); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h deleted file mode 100644 index e07158b0602eef1d71cfdca95323b3da60553747..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h +++ /dev/null @@ -1,543 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines reference operations for GEMM operation kinds in CUTLASS Library -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/util.h" -#include "library_internal.h" - -#include "cutlass/util/reference/host/gemm_complex.h" -#include "cutlass/util/reference/device/gemm_complex.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - Provider Provider_, - typename ElementA_, - typename LayoutA_, - cutlass::ComplexTransform TransformA, - typename ElementB_, - typename LayoutB_, - cutlass::ComplexTransform TransformB, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -class GemmReferenceOperation : public Operation { -public: - static Provider const kProvider = Provider_; - - using ElementA = ElementA_; - using LayoutA = LayoutA_; - using TensorRefA = TensorRef; - static cutlass::ComplexTransform const kTransformA = TransformA; - using ElementB = ElementB_; - using LayoutB = LayoutB_; - using TensorRefB = TensorRef; - static cutlass::ComplexTransform const kTransformB = TransformB; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using ElementD = ElementD_; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - using ElementCompute = ElementCompute_; - using ElementAccumulator = ElementAccumulator_; - using ConvertOp = ConvertOp_; - using InnerProductOp = InnerProductOp_; - -protected: - - /// Storage for the name string - std::string name_; - - /// - GemmDescription description_; - -public: - - /// Constructor - GemmReferenceOperation() { - - // Basic information - description_.provider = kProvider; - description_.kind = OperationKind::kGemm; - description_.gemm_kind = GemmKind::kUniversal; - - // Tensor description - description_.A = make_TensorDescription(); - description_.transform_A = ComplexTransformMap::kId; - description_.B = make_TensorDescription(); - description_.transform_B = ComplexTransformMap::kId; - description_.C = make_TensorDescription(); - description_.D = make_TensorDescription(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - // Compute capability for gemm reference - description_.tile_description.minimum_compute_capability = - (kProvider == Provider::kReferenceDevice ? 50 : 0); - - description_.tile_description.maximum_compute_capability = 1024; - - // Procedural name - std::stringstream ss; - - ss << "gemm" - << "_reference_" << to_string(description_.provider) - << "_" << to_string(description_.A.element) << to_string(description_.A.layout) - << "_" << to_string(description_.B.element) << to_string(description_.B.layout) - << "_" << to_string(description_.C.element) << to_string(description_.C.layout) - << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); - - name_ = ss.str(); - - description_.name = name_.c_str(); - - // Epilogue compute and accumulator type description - description_.element_epilogue = NumericTypeMap::kId; - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - } - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - - virtual Status can_implement( - void const *configuration, - void const *arguments) const { - - return Status::kSuccess; - } - - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(GemmUniversalConfiguration); - } - - virtual uint64_t get_device_workspace_size( - void const *configuration, - void const *arguments = nullptr) const { - - return 0; - } - - virtual Status initialize( - void const *configuration, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); - - return Status::kSuccess; - } - - virtual Status run( - void const *arguments, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - GemmUniversalConfiguration const &config = *static_cast(host_workspace); - GemmUniversalArguments const &args = *static_cast(arguments); - - TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; - TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; - TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; - TensorRefD ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; - - if (kProvider == Provider::kReferenceHost) { - - cutlass::reference::host::GemmComplex< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ElementD, - ConvertOp, - InnerProductOp - >( - config.problem_size, - *static_cast(args.alpha), - ref_A, - kTransformA, - ref_B, - kTransformB, - *static_cast(args.beta), - ref_C, - ref_D, - ElementAccumulator(), - ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), - args.batch_stride_A, - args.batch_stride_B, - args.batch_stride_C, - args.batch_stride_D - ); - - return Status::kSuccess; - } - else if (kProvider == Provider::kReferenceDevice) { - - cutlass::reference::device::GemmComplex< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ElementD, - ConvertOp, - InnerProductOp - >( - config.problem_size, - *static_cast(args.alpha), - ref_A, - kTransformA, - ref_B, - kTransformB, - *static_cast(args.beta), - ref_C, - ref_D, - ElementAccumulator(), - ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), - args.batch_stride_A, - args.batch_stride_B, - args.batch_stride_C, - args.batch_stride_D - ); - - return Status::kSuccess; - } - - return Status::kErrorNotSupported; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA_, - typename LayoutA_, - cutlass::ComplexTransform TransformA, - typename ElementB_, - typename LayoutB_, - cutlass::ComplexTransform TransformB, - typename ElementC_, - typename LayoutC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_gemm(Manifest &manifest) { -#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) - manifest.append(new GemmReferenceOperation< - Provider::kReferenceHost, - ElementA_, LayoutA_, TransformA, - ElementB_, LayoutB_, TransformB, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >); - - manifest.append(new GemmReferenceOperation< - Provider::kReferenceDevice, - ElementA_, LayoutA_, TransformA, - ElementB_, LayoutB_, TransformB, - ElementC_, LayoutC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >); -#endif -} - -/// Helper to create NN, NT, TN, and TT GEMM layouts. -template < - typename ElementA_, cutlass::ComplexTransform TransformA, - typename ElementB_, cutlass::ComplexTransform TransformB, - typename ElementC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_gemm_canonical_layouts(Manifest &manifest) { - - // M Major outputs - make_gemm< - ElementA_, cutlass::layout::ColumnMajor, TransformA, - ElementB_, cutlass::layout::ColumnMajor, TransformB, - ElementC_, cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm< - ElementA_, cutlass::layout::ColumnMajor, TransformA, - ElementB_, cutlass::layout::RowMajor, TransformB, - ElementC_, cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm< - ElementA_, cutlass::layout::RowMajor, TransformA, - ElementB_, cutlass::layout::ColumnMajor, TransformB, - ElementC_, cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm< - ElementA_, cutlass::layout::RowMajor, TransformA, - ElementB_, cutlass::layout::RowMajor, TransformB, - ElementC_, cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - // N Major outputs - make_gemm< - ElementA_, cutlass::layout::ColumnMajor, TransformA, - ElementB_, cutlass::layout::ColumnMajor, TransformB, - ElementC_, cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm< - ElementA_, cutlass::layout::ColumnMajor, TransformA, - ElementB_, cutlass::layout::RowMajor, TransformB, - ElementC_, cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm< - ElementA_, cutlass::layout::RowMajor, TransformA, - ElementB_, cutlass::layout::ColumnMajor, TransformB, - ElementC_, cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm< - ElementA_, cutlass::layout::RowMajor, TransformA, - ElementB_, cutlass::layout::RowMajor, TransformB, - ElementC_, cutlass::layout::RowMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); -} - - -/// Helper to create TN and interleaved layouts GEMM layouts. -template < - int InterleaveK, - typename ElementA_, - typename ElementB_, - typename ElementC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_gemm_interleaved_layouts(Manifest &manifest) { - - make_gemm< - ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, - ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, - ElementC_, cutlass::layout::ColumnMajor, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - -} - -/// Helper to real-valued GEMM with canonical layouts -template < - typename ElementA_, - typename ElementB_, - typename ElementC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_gemm_real_canonical_layouts(Manifest &manifest) { - make_gemm_canonical_layouts< - ElementA_, cutlass::ComplexTransform::kNone, - ElementB_, cutlass::ComplexTransform::kNone, - ElementC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); -} - -// Helper to create all complex transformation permutations -template < - typename ElementA_, - typename ElementB_, - typename ElementC_, - typename ElementCompute_, - typename ElementAccumulator_ = ElementCompute_, - typename ElementD_ = ElementC_, - typename ConvertOp_ = NumericConverter, - typename InnerProductOp_ = multiply_add -> -void make_gemm_complex_canonical_layouts(Manifest &manifest) { - - make_gemm_canonical_layouts< - ElementA_, cutlass::ComplexTransform::kNone, - ElementB_, cutlass::ComplexTransform::kNone, - ElementC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm_canonical_layouts< - ElementA_, cutlass::ComplexTransform::kConjugate, - ElementB_, cutlass::ComplexTransform::kConjugate, - ElementC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm_canonical_layouts< - ElementA_, cutlass::ComplexTransform::kNone, - ElementB_, cutlass::ComplexTransform::kConjugate, - ElementC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); - - make_gemm_canonical_layouts< - ElementA_, cutlass::ComplexTransform::kConjugate, - ElementB_, cutlass::ComplexTransform::kNone, - ElementC_, - ElementCompute_, - ElementAccumulator_, - ElementD_, - ConvertOp_, - InnerProductOp_ - >(manifest); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp deleted file mode 100644 index 01caa11e229ffd9109b0973dcca01064df448fa3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp +++ /dev/null @@ -1,504 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/collective.hpp" -#include "cutlass/array.h" -#include "cutlass/array_subbyte.h" -#include "cutlass/library/library.h" -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor -#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter -#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride -#include "gemm_operation_3x.hpp" -#include "library_internal.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/mixed_dtype_utils.hpp" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/device/tensor_fill.h" -#include "cutlass/util/reference/device/tensor_compare.h" -#include "cute/tensor.hpp" -#include - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Limitation & Assumptions: -// 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, -// and lda is m if the tensor is m-major. -// 2. Circular buffer for tensorA and tensorE may have a less count compared to tensorB and others. -// This is because we can not get the problem_count information in the get_device_workspace_size(). -// But I can promise it will use at least 192MB memory if we enable circular buffer. -template -class SparseGemmUniversal3xOperation : public GemmOperation3xBase { -public: - - using Operator = Operator_; - using OperatorArguments = typename Operator::Arguments; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementD = typename Operator::ElementD; - using LayoutD = typename Operator::LayoutD; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using CollectiveMainloop = typename Operator::CollectiveMainloop; - using CollectiveEpilogue = typename Operator::CollectiveEpilogue; - using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; - - static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - - static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), - "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); - - static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - - using ElementE = typename CollectiveMainloop::ElementE; - using LayoutE = typename CollectiveMainloop::LayoutE; - using SparseConfig = typename CollectiveMainloop::SparseConfig; - using LayoutATag = decltype(SparseConfig::deduce_layoutA_tag(typename CollectiveMainloop::LayoutA{})); - using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< - cute::Shape, - ElementA, - LayoutATag, - SparseConfig>; - using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< - cute::Shape, - ElementA, - LayoutATag, - SparseConfig, - typename Operator::ArchTag>; - - using Compressor = cutlass::transform::device::TransformUniversalAdapter; - -public: - - /// Constructor - SparseGemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) {} - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { - // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides - // Do nothing here and construct kernel arguments in update_arguments_ instead - // We also cannot construct TMA descriptors without all the arguments available - - operator_args.mode = configuration->mode; - return Status::kSuccess; - } - - template - struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { - if (arguments.pointer_mode == ScalarPointerMode::kHost) { - fusion_args.alpha = *static_cast(arguments.alpha); - fusion_args.beta = *static_cast(arguments.beta); - fusion_args.alpha_ptr = nullptr; - fusion_args.beta_ptr = nullptr; - - return Status::kSuccess; - } - else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; - fusion_args.alpha_ptr = static_cast(arguments.alpha); - fusion_args.beta_ptr = static_cast(arguments.beta); - - return Status::kSuccess; - } - else { - return Status::kErrorInvalidProblem; - } - } - }; - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - GemmUniversalArguments const *arguments, - CompressorUtility const& compressor_utility, - void* device_a_compressed_ptr = nullptr, - void* device_e_ptr = nullptr) { - Status status = Status::kSuccess; - - status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, *arguments); - if (status != Status::kSuccess) { - return status; - } - - operator_args.problem_shape = cute::make_shape( - arguments->problem_size.m(), - arguments->problem_size.n(), - arguments->problem_size.k(), - arguments->batch_count); - - // update arguments - - if constexpr (IsRuntimeDataType) { - using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; - using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; - operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - - std::unordered_map mapping = { - {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, - {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, - {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, - {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} - }; - - auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); - auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); - - if (iter_runtime_a != mapping.end()) { - operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; - } else { - assert("invalid runtime argument for datatype A!"); - } - - if (iter_runtime_b != mapping.end()) { - operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; - } else { - assert("invalid runtime argument for datatype B!"); - } - - } - else { - operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); - operator_args.mainloop.ptr_B = static_cast(arguments->B); - } - operator_args.mainloop.ptr_E = static_cast(device_e_ptr); - operator_args.epilogue.ptr_C = static_cast(arguments->C); - operator_args.epilogue.ptr_D = static_cast(arguments->D); - - operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor(); - operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor(); - operator_args.mainloop.dB = cute::make_int_tuple_from( - arguments->ldb, arguments->batch_stride_B); - operator_args.epilogue.dC = cute::make_int_tuple_from( - arguments->ldc, arguments->batch_stride_C); - operator_args.epilogue.dD = operator_args.epilogue.dC; - - /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ - operator_args.hw_info.sm_count = arguments->sm_count; - if constexpr (!std::is_const_v) { - operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; - } - - if constexpr (!std::is_const_v) { - using Enum_t = decltype(operator_args.scheduler.raster_order); - switch (arguments->raster_order) { - case RasterOrder::kAlongN: - operator_args.scheduler.raster_order = Enum_t::AlongN; - break; - case RasterOrder::kAlongM: - operator_args.scheduler.raster_order = Enum_t::AlongM; - break; - default: - operator_args.scheduler.raster_order = Enum_t::Heuristic; - } - } - - if constexpr (std::is_same_v) { - operator_args.scheduler.splits = arguments->split_k_slices; - } - - if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { - operator_args.hw_info.cluster_shape = dim3( - arguments->cluster_shape.m(), - arguments->cluster_shape.n(), - arguments->cluster_shape.k()); - operator_args.hw_info.cluster_shape_fallback = dim3( - arguments->cluster_shape_fallback.m(), - arguments->cluster_shape_fallback.n(), - arguments->cluster_shape_fallback.k()); - } - return status; - } - -public: - - /// Returns success if the operation can proceed - Status can_implement( - void const *configuration_ptr, void const *arguments_ptr) const override { - - GemmUniversalConfiguration const *configuration = - static_cast(configuration_ptr); - GemmUniversalArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - auto problem_shape_MNKL = cute::make_shape( - configuration->problem_size.m(), - configuration->problem_size.n(), - configuration->problem_size.k(), - configuration->batch_count); - - const int M = configuration->problem_size.m(); - const int N = configuration->problem_size.n(); - const int K = configuration->problem_size.k(); - const int L = configuration->batch_count; - using StrideA = typename CompressorUtility::StrideA; - auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - compressor_utility.set_problem_size(problem_shape_MNKL, dA); - auto status = update_arguments_(args, arguments, compressor_utility); - if (status != Status::kSuccess) { - return status; - } - - // can_implement rules may need access to problem shape - args.problem_shape = problem_shape_MNKL; - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - uint64_t get_host_workspace_size(void const *) const override { - // Memory to hold operator - host_op_workspace_size = sizeof(Operator); - - // Memory to hold result of `.structure_sparse_zero_mask_fill()` - tensor_a_size = compressor_utility.get_raw_tensor_A_bytes(); - - // NOTE: order here is the order of workspace partition - const uint64_t size = host_op_workspace_size + tensor_a_size; - - return size; - } - - /// Gets the device-side workspace - uint64_t get_device_workspace_size( - void const *configuration_ptr,void const *arguments_ptr) const override { - - OperatorArguments args; - auto status = update_arguments_( - args, static_cast(arguments_ptr), compressor_utility); - if (status != Status::kSuccess) { - return 0; - } - - typename Compressor::Arguments compress_arguments { - {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, - {/*Empty Not Use*/}, - {/*Empty Not Use*/} }; - - // Size for one iteration - // For multi-iteration, will need to multiply result of this function w/ actual problem_count - tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes(); - tensor_e_size = compressor_utility.get_tensor_E_bytes(); - device_op_workspace_size = Operator::get_workspace_size(args); - device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments); - - // NOTE: order here is the order of workspace partition - device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size; - - return device_per_iter_workspace_size; - } - - /// Initializes the workspace - Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const override { - return Status::kErrorInternal; - } - - Status initialize_with_profiler_workspace( - void const *configuration, - void *host_workspace, - void *device_workspace, - uint8_t **profiler_workspaces, - int problem_count_from_profiler, - cudaStream_t stream = nullptr) { - - iter_idx.resize(static_cast(configuration)->device_count, 0); - - // Set problem_count. - problem_count = problem_count_from_profiler; - - // * Host Ptr - auto* host_op_workspace_ptr = reinterpret_cast(host_workspace); - auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size; - - // * Construct Op - Operator *op = new (host_op_workspace_ptr) Operator; - - // * Device Ptr (1st iteration) - // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | - // iteri : op_workspace | tensor_ac | tensor_e - auto* device_ptr_iter1 = static_cast(device_workspace); - auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; - auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; - auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; - auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size; - - // * Device A Raw Ptr - auto* device_a_raw_ptr = profiler_workspaces[0]; - - // * Random fill 50% of TensorA w/ zero following the structured sparse requirement - CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); - compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); - CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); - - CUDA_CHECK(cudaGetLastError()); - - // * Compress DTensorA and get DTensorAC & DTensorE - cutlass::KernelHardwareInfo hw_info; - CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Compressor::Arguments arguments{ - {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, - {device_a_raw_ptr, - compressor_utility.dA, - device_a_compressed_ptr_iter1, - device_e_ptr_iter1}, - {hw_info} - }; - - cutlass::Status status {cutlass::Status::kSuccess }; - - Compressor compressor_op; - status = compressor_op.can_implement(arguments); - if (status != Status::kSuccess) { - return status; - } - - status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream); - if (status != Status::kSuccess) { - return status; - } - - status = compressor_op.run(stream); - if (status != Status::kSuccess) { - return status; - } - - // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE - for (int iter_i = 1; iter_i < problem_count; iter_i++) { - // * Device AC E Ptr per iteration - // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | - // iteri : op_workspace | tensor_ac | tensor_e - auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; - auto* device_op_workspace_ptr = device_ptr_iteri; - auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; - auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; - auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; - - CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); - } - - CUDA_CHECK(cudaStreamSynchronize(stream)); - - CUDA_CHECK(cudaGetLastError()); - - return Status::kSuccess; - } - - /// Runs the kernel - Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const override { - - OperatorArguments operator_args; - - - const auto device_index = static_cast(arguments_ptr)->device_index; - - auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; - auto* device_op_workspace_ptr = device_ptr_iteri; - auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; - auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; - auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; - iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; - - Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, - static_cast(arguments_ptr)->use_pdl); - return status; - } - -private: - // Variables that must change in the const functions. - mutable CompressorUtility compressor_utility; - mutable int problem_count = 1; - mutable std::vector iter_idx; - - mutable uint64_t tensor_ac_size = 0; - mutable uint64_t tensor_e_size = 0; - mutable uint64_t tensor_a_size = 0; - mutable uint64_t host_op_workspace_size = 0; - mutable uint64_t device_compress_workspace_size = 0; - mutable uint64_t device_op_workspace_size = 0; - mutable uint64_t device_per_iter_workspace_size = 0; -}; -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::library - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h deleted file mode 100644 index c95d238a81f825dbbeae689ec452467cc8ca3afa..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h +++ /dev/null @@ -1,382 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all Symm operation kinds (Symm, Hemm) - in CUTLASS Library. - - -*/ - -#pragma once -#include -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/device/symm.h" -#include "cutlass/gemm/kernel/default_symm_universal.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" -#include "cutlass/core_io.h" -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class SymmOperationBase : public Operation { -public: - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - static BlasMode const kBlasMode = Operator::kBlasMode; - static SideMode const kSideModeA = Operator::kSideModeA; - static FillMode const kFillModeA = Operator::kFillModeA; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - SymmDescription description_; - -public: - - /// Constructor - SymmOperationBase(char const *name = "unknown_symm") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.symm_kind = SymmKind::kUniversal; - description_.side_mode = kSideModeA; - description_.fill_mode = kFillModeA; - description_.blas_mode = kBlasMode; - - description_.kind = OperationKind::kSymm; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::SymmKernel::WarpCount::kM, - Operator::SymmKernel::WarpCount::kN, - Operator::SymmKernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.C = make_TensorDescription(Operator::kAlignmentC); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - } - - /// Returns the description of the SYMM operation - virtual OperationDescription const & description() const { - return description_; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class SymmOperation : public SymmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - static BlasMode const kBlasMode = Operator::kBlasMode; - static SideMode const kSideModeA = Operator::kSideModeA; - static FillMode const kFillModeA = Operator::kFillModeA; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - SymmOperation(char const *name = "unknown_symm"): - SymmOperationBase(name) { - - this->description_.symm_kind = SymmKind::kUniversal; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - SymmConfiguration const *configuration) { - - //operator_args.mode = configuration->mode; - - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - operator_args.lda = int(configuration->lda); - operator_args.ldb = int(configuration->ldb); - operator_args.ldc = int(configuration->ldc); - operator_args.ldd = int(configuration->ldd); - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - SymmArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A = arguments->A; - operator_args.ptr_B = arguments->B; - operator_args.ptr_C = arguments->C; - operator_args.ptr_D = arguments->D; - - operator_args.batch_stride_A = arguments->batch_stride_A; - operator_args.batch_stride_B = arguments->batch_stride_B; - operator_args.batch_stride_C = arguments->batch_stride_C; - operator_args.batch_stride_D = arguments->batch_stride_D; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - SymmConfiguration const *configuration = - static_cast(configuration_ptr); - - SymmArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - //std::cout << "initialize() library::SymmOperation" << std::endl; - //print_operator_args(args); - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && - std::is_same::value) || - (kSideModeA == SideMode::kRight && - std::is_same::value); - if (need_swapped_matrices) { - status = op->update(args.swapped_matrices(), device_workspace); - } else { - status = op->update(args, device_workspace); - } - - if (status != Status::kSuccess) { - return status; - } - - //std::cout << "run() library::SymmOperation" << std::endl; - //print_operator_args(args); - status = op->run(stream); - - return status; - } - - /// Call print_operator_args from the Conv2dOperation::initialize() - // to dump arguments passed on to cutlass operator for debugging - void print_operator_args(OperatorArguments &operator_args) const { - std::cout << "SymmOperation::OperatorArguments" << std::endl - << " problem_size:" << std::endl - << operator_args.problem_size << std::endl - << " epilogue (alpha, beta): " - << operator_args.epilogue.alpha << ", " - << operator_args.epilogue.beta << std::endl - << " ref_A (ptr, {stride}): " - << operator_args.ptr_A << ", {" - << operator_args.lda << "}" << std::endl - << " ref_B (ptr, {stride}): " - << operator_args.ptr_B << ", {" - << operator_args.ldb << "}" << std::endl - << " ref_C (ptr, {stride}): " - << operator_args.ptr_C << ", {" - << operator_args.ldc << "}" << std::endl - << " ref_D (ptr, {stride}): " - << operator_args.ptr_D << ", {" - << operator_args.ldd << "}" << std::endl; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h deleted file mode 100644 index d419723791ace5d90eb7955223be9db72bbc2c3c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h +++ /dev/null @@ -1,350 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines operations for all TRMM operation kinds in CUTLASS Library. - - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/device/trmm.h" -#include "cutlass/gemm/kernel/default_trmm_universal.h" -#include "cutlass/gemm/kernel/trmm_universal.h" - -#include "cutlass/library/library.h" -#include "library_internal.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class TrmmOperationBase : public Operation { -public: - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - static SideMode const kSideMode = Operator::kSideMode; - static FillMode const kFillMode = Operator::kFillMode; - static DiagType const kDiagType = Operator::kDiagType; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - TrmmDescription description_; - -public: - - /// Constructor - TrmmOperationBase(char const *name = "unknown_trmm") { - - description_.name = name; - description_.provider = Provider::kCUTLASS; - description_.kind = OperationKind::kTrmm; - description_.trmm_kind = TrmmKind::kUniversal; - description_.side_mode = kSideMode; - description_.fill_mode = kFillMode; - description_.diag_type = kDiagType; - - description_.tile_description.threadblock_shape = make_Coord( - Operator::ThreadblockShape::kM, - Operator::ThreadblockShape::kN, - Operator::ThreadblockShape::kK); - - description_.tile_description.threadblock_stages = Operator::kStages; - - description_.tile_description.warp_count = make_Coord( - Operator::TrmmKernel::WarpCount::kM, - Operator::TrmmKernel::WarpCount::kN, - Operator::TrmmKernel::WarpCount::kK); - - description_.tile_description.math_instruction.instruction_shape = make_Coord( - Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); - - description_.tile_description.math_instruction.element_accumulator = - NumericTypeMap::kId; - - description_.tile_description.math_instruction.opcode_class = - OpcodeClassMap::kId; - - description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; - - description_.tile_description.minimum_compute_capability = - ArchMap::kMin; - - description_.tile_description.maximum_compute_capability = - ArchMap::kMax; - - description_.A = make_TensorDescription(Operator::kAlignmentA); - description_.B = make_TensorDescription(Operator::kAlignmentB); - description_.D = make_TensorDescription(Operator::kAlignmentC); - description_.element_epilogue = NumericTypeMap::kId; - - description_.split_k_mode = SplitKMode::kNone; - description_.transform_A = ComplexTransformMap::kId; - } - - /// Returns the description of the TRMM operation - virtual OperationDescription const & description() const { - return description_; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class TrmmOperation : public TrmmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - static SideMode const kSideMode = Operator::kSideMode; - static FillMode const kFillMode = Operator::kFillMode; - static DiagType const kDiagType = Operator::kDiagType; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -public: - - /// Constructor - TrmmOperation(char const *name = "unknown_trmm"): - TrmmOperationBase(name) { - - this->description_.trmm_kind = TrmmKind::kUniversal; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - TrmmConfiguration const *configuration) { - - //operator_args.mode = configuration->mode; - - operator_args.problem_size = configuration->problem_size; - operator_args.batch_count = configuration->batch_count; - - operator_args.lda = int(configuration->lda); - operator_args.ldb = int(configuration->ldb); - operator_args.ldd = int(configuration->ldd); - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - TrmmArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - // update arguments - operator_args.ptr_A = arguments->A; - operator_args.ptr_B = arguments->B; - operator_args.batch_stride_A = arguments->batch_stride_A; - operator_args.batch_stride_B = arguments->batch_stride_B; - operator_args.ptr_D = arguments->D; - operator_args.batch_stride_D = arguments->batch_stride_D; - - if (arguments->use_pdl) { - return Status::kErrorNotSupported; - } - - return Status::kSuccess; - } - -public: - - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - TrmmConfiguration const *configuration = - static_cast(configuration_ptr); - - TrmmArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr, - void const *arguments_ptr = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - uint64_t size = Operator::get_workspace_size(args); - - return size; - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - status = op->initialize(args, device_workspace, stream); - - return status; - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - bool need_swapped_matrices = (kSideMode == SideMode::kLeft && - std::is_same::value) || - (kSideMode == SideMode::kRight && - std::is_same::value); - if (need_swapped_matrices) { - status = op->update(args.swapped_matrices(), device_workspace); - } else { - status = op->update(args, device_workspace); - } - - if (status != Status::kSuccess) { - return status; - } - - status = op->run(stream); - - return status; - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h deleted file mode 100644 index 5d500d9149bf645eadf8110d98612c40882d742c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h +++ /dev/null @@ -1,330 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Blockscale Gemm Profiler -*/ - - - -#pragma once - -#include -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#include "reduction_operation_profiler.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class BlockScaledGemmOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct GemmProblem { - - cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; - - /// For profiling purposes - std::vector problem_sizes; - std::vector> leading_dims; - std::vector> preferred_clusters; - std::vector> fallback_clusters; - std::vector raster_orders; - std::vector swizzle_sizes; - - int64_t m{16}; - int64_t n{16}; - int64_t k{16}; - - - int cluster_m{1}; - int cluster_n{1}; - int cluster_k{1}; - int cluster_m_fallback{1}; - int cluster_n_fallback{1}; - int cluster_k_fallback{1}; - - - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - std::vector alpha; - std::vector beta; - - cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; - int split_k_slices{1}; - int batch_count{1}; - - cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; - int swizzle_size{1}; - cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; - cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; - - - // gemm with parallel interleaved reduction - // gemm epilogue (alpha, beta) = (1.0, 0.0) - // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) - std::vector alpha_one; - std::vector beta_zero; - - bool use_pdl{false}; - // - // Methods - // - - /// Parses the problem - Status parse( - library::BlockScaledGemmDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - int64_t bytes_with_problem_shape( - library::BlockScaledGemmDescription const &operation_desc, - gemm::GemmCoord const &problem_shape) const; - - int64_t flops_with_problem_shape( - library::BlockScaledGemmDescription const &operation_desc, - gemm::GemmCoord const &problem_shape) const; - - /// Total number of bytes loaded - int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::BlockScaledGemmDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct GemmWorkspace { - - DeviceAllocation *A{nullptr}; - DeviceAllocation *SFA{nullptr}; - DeviceAllocation *B{nullptr}; - DeviceAllocation *SFB{nullptr}; - DeviceAllocation *C{nullptr}; - DeviceAllocation *Computed{nullptr}; - DeviceAllocation *Reference{nullptr}; - DeviceAllocation *Computed_SFD{nullptr}; - DeviceAllocation *Reference_SFD{nullptr}; - DeviceAllocation *Norm_constant{nullptr}; - - /// Number of copies of the problem workspace which are visited sequentially during - /// profiling to avoid camping in the last level cache. - int problem_count{1}; - - library::GemmUniversalConfiguration configuration; - library::BlockScaledGemmArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - /// Library configuration and arguments for reduction operator - library::ReductionConfiguration reduction_configuration; - library::ReductionArguments reduction_arguments; - - /// Buffer used for the cutlass reduction operations' host workspace - std::vector reduction_host_workspace; - - cudaStream_t stream; - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - GemmProblem problem_; - - /// Device memory allocations - GemmWorkspace gemm_workspace_; - - /// CUTLASS parallel reduction operation to follow this* gemm operation - library::Operation const *reduction_op_; - -public: - // - // Methods - // - - /// Ctor - BlockScaledGemmOperationProfiler(Options const &options); - - /// Destructor - virtual ~BlockScaledGemmOperationProfiler(); - - GemmProblem const& problem() const { return problem_; } - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Update workspace configuration according to flexible user setups - void update_workspace_( - GemmWorkspace &gemm_workspace, - gemm::GemmCoord const &problem_shape, - std::array const &leading_dim, - std::array const &preferred_cluster, - std::array const &fallback_cluster, - cutlass::library::RasterOrder const &raster_order, - int swizzle_size, - bool is_dynamic_cluster_enabled); - - /// Update performance result configuration according to flexible user setups - void update_result_( - PerformanceResult &result, - library::BlockScaledGemmDescription const &operation_desc, - ProblemSpace const &problem_space, - gemm::GemmCoord const &problem_shape, - cutlass::library::RasterOrder const &raster_order, - std::array const &preferred_cluster, - std::array const &fallback_cluster, - int swizzle_size, - bool is_dynamic_cluster_enabled); - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::BlockScaledGemmDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against host and device references - bool verify_with_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem, - cutlass::library::NumericTypeID element_A, - cutlass::library::NumericTypeID element_B); - - /// Method to profile a CUTLASS Operation - Status profile_cutlass_( - PerformanceResult &result, - Options const &options, - library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace); - - /// Initialize reduction problem dimensions and library::Operation - bool initialize_reduction_configuration_( - library::Operation const *operation, - ProblemSpace::Problem const &problem); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h deleted file mode 100644 index c110de278cac640c1cedd8dd29d1b8ac09de81ef..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h +++ /dev/null @@ -1,305 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Blockscale Gemm Profiler -*/ - - - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#include "reduction_operation_profiler.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class BlockwiseGemmOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct GemmProblem { - - cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; - - int64_t m{16}; - int64_t n{16}; - int64_t k{16}; - - int64_t sf_vec_m{0}; - int64_t sf_vec_n{0}; - int64_t sf_vec_k{0}; - - int cluster_m{1}; - int cluster_n{1}; - int cluster_k{1}; - int cluster_m_fallback{1}; - int cluster_n_fallback{1}; - int cluster_k_fallback{1}; - - - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - std::vector alpha; - std::vector beta; - - cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; - int split_k_slices{1}; - int batch_count{1}; - - cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; - int swizzle_size{1}; - - /// For profiling purposes - std::vector problem_sizes; - std::vector> leading_dims; - std::vector> preferred_clusters; - std::vector> fallback_clusters; - std::vector raster_orders; - std::vector swizzle_sizes; - - cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; - cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; - - - // gemm with parallel interleaved reduction - // gemm epilogue (alpha, beta) = (1.0, 0.0) - // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) - std::vector alpha_one; - std::vector beta_zero; - - bool use_pdl{false}; - // - // Methods - // - - /// Parses the problem - Status parse( - library::BlockwiseGemmDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - int64_t bytes_with_problem_shape( - library::BlockwiseGemmDescription const &operation_desc, - gemm::GemmCoord const &problem_shape) const; - - int64_t flops_with_problem_shape( - library::BlockwiseGemmDescription const &operation_desc, - gemm::GemmCoord const &problem_shape) const; - - /// Total number of bytes loaded - int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::BlockwiseGemmDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct GemmWorkspace { - - DeviceAllocation *A{nullptr}; - DeviceAllocation *SFA{nullptr}; - DeviceAllocation *B{nullptr}; - DeviceAllocation *SFB{nullptr}; - DeviceAllocation *C{nullptr}; - DeviceAllocation *Computed{nullptr}; - DeviceAllocation *Reference{nullptr}; - - /// Number of copies of the problem workspace which are visited sequentially during - /// profiling to avoid camping in the last level cache. - int problem_count{1}; - - library::GemmUniversalConfiguration configuration; - library::BlockwiseGemmArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - /// Library configuration and arguments for reduction operator - library::ReductionConfiguration reduction_configuration; - library::ReductionArguments reduction_arguments; - - /// Buffer used for the cutlass reduction operations' host workspace - std::vector reduction_host_workspace; - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - GemmProblem problem_; - - /// Device memory allocations - GemmWorkspace gemm_workspace_; - - /// CUTLASS parallel reduction operation to follow this* gemm operation - library::Operation const *reduction_op_; - -public: - // - // Methods - // - - /// Ctor - BlockwiseGemmOperationProfiler(Options const &options); - - /// Destructor - virtual ~BlockwiseGemmOperationProfiler(); - - GemmProblem const& problem() const { return problem_; } - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::BlockwiseGemmDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against host and device references - bool verify_with_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem, - cutlass::library::NumericTypeID element_A, - cutlass::library::NumericTypeID element_B); - - /// Method to profile a CUTLASS Operation - Status profile_cutlass_( - PerformanceResult &result, - Options const &options, - library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace); - - /// Initialize reduction problem dimensions and library::Operation - bool initialize_reduction_configuration_( - library::Operation const *operation, - ProblemSpace::Problem const &problem); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h deleted file mode 100644 index 683465f50cda19c8d505f2e66bcb60173d7e942d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h +++ /dev/null @@ -1,495 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines profiling functionality for convolution - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/handle.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/singleton.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#include "reduction_operation_profiler.h" -#if CUTLASS_ENABLE_CUDNN -#include "cudnn_helpers.h" -#endif //#if CUTLASS_ENABLE_CUDNN -#include "debug.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class Conv2dOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct Conv2dProblem { - - int64_t n, h, w, c, p, q, k, r, s; - int64_t groups; - int64_t pad_h, pad_w; - int64_t stride_h, stride_w; - int64_t dilation_h, dilation_w; - - std::vector alpha; - std::vector beta; - - library::SplitKMode split_k_mode; - int64_t split_k_slices; - - library::ConvModeID conv_mode; - - library::Provider eq_gemm_provider; - - // convolution with parallel interleaved reduction - // convolution epilogue (alpha, beta) = (1.0, 0.0) - // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) - std::vector alpha_one; - std::vector beta_zero; - - // - // Methods - // - - /// Total number of bytes loaded - int64_t bytes(library::ConvDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::ConvDescription const &operation_desc) const; - - void set_default_output_size() { - p = ((h + pad_h - r * dilation_h) / stride_h) + 1; - q = ((w + pad_w - s * dilation_w) / stride_w) + 1; - } - - // Returns equivalent gemm problem size for convolution - cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); - case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); - case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns extent for tensor A - std::vector extent_a(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; - case library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; - case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns extent for tensor B - std::vector extent_b(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; - case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; - case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns extent for tensor C - std::vector extent_c(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; - case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; - case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns layout for equivalent gemm matrix A - library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm - case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm - case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns layout for equivalent gemm matrix B - library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm - case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm - case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns layout for equivalent gemm matrix C - library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - // Gemm operator assumes column-major output - case library::ConvKind::kFprop: - case library::ConvKind::kDgrad: - case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns leading dimension for equivalent gemm matrix A - int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); - case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); - case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns leading dimension for equivalent gemm matrix B - int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); - case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); - case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns leading dimension for equivalent gemm matrix C - int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: - case library::ConvKind::kDgrad: - case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - }; - - /// Workspace used - struct Conv2dWorkspace { - - /// Conv device allocations - DeviceAllocation *A; - DeviceAllocation *B; - DeviceAllocation *reordered_B; - DeviceAllocation *C; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - /// Library configuration and arguments for convolution operator - library::Conv2dConfiguration configuration; - library::ConvArguments arguments; - - /// Number of copies of the problem workspace which are visited sequentially during - /// profiling to avoid camping in the last level cache. - int problem_count; - - /// Buffer used for the cutlass conv2d operations' host workspace - std::vector host_workspace; - - /// Buffer used for the cutlass operations' device workspace - DeviceAllocation device_workspace; - - /// Library configuration and arguments for reduction operator - library::ReductionConfiguration reduction_configuration; - library::ReductionArguments reduction_arguments; - - /// Buffer used for the cutlass reduction operations' host workspace - std::vector reduction_host_workspace; - - /// Host data buffers for host reference operation - /// host buffer for tensor - std::vector host_tensor_a; - - /// host buffer for tensor b - std::vector host_tensor_b; - - /// host buffer for tensor c - std::vector host_tensor_c; - - // - // Methods - // - - Conv2dWorkspace() - : A(nullptr), - B(nullptr), - reordered_B(nullptr), - C(nullptr), - Computed(nullptr), - Reference(nullptr) {} - - // Set stride vector for tensor activations, filters, output - void set_stride_vector(Conv2dProblem const &problem, - library::ConvKind const &conv_kind, - library::LayoutTypeID const &layout_a, - library::LayoutTypeID const &layout_b, - library::LayoutTypeID const &layout_c) { - std::vector stride_activations; - std::vector stride_filters; - std::vector stride_output; - - // Strides for interleaved fprop - if (conv_kind == library::ConvKind::kFprop && - ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && - layout_b == library::LayoutTypeID::kTensorC32RSK32 && - layout_c == library::LayoutTypeID::kTensorNC32HW32) || - (layout_a == library::LayoutTypeID::kTensorNC64HW64 && - layout_b == library::LayoutTypeID::kTensorC64RSK64 && - layout_c == library::LayoutTypeID::kTensorNC64HW64))) { - int interleave = - (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 32 : 64; - - stride_activations.push_back(int(problem.w) * interleave); - stride_activations.push_back(int(problem.w) * int(problem.h) * - interleave); - stride_activations.push_back(int(problem.h) * int(problem.w) * - int(problem.c)); - - stride_filters.push_back(int(problem.k) * interleave); - stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); - stride_filters.push_back(int(problem.k) * int(problem.s) * - int(problem.r) * interleave); - - stride_output.push_back(int(problem.q) * interleave); - stride_output.push_back(int(problem.q) * int(problem.p) * interleave); - stride_output.push_back(int(problem.q) * int(problem.p) * - int(problem.k)); - } else { - // Strides for the rest cases - stride_activations.push_back(int(problem.c)); - stride_activations.push_back(int(problem.w) * int(problem.c)); - stride_activations.push_back(int(problem.h) * int(problem.w) * - int(problem.c)); - - stride_filters.push_back(int(problem.c / problem.groups)); - stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); - stride_filters.push_back(int(problem.r) * int(problem.s) * - int(problem.c / problem.groups)); - - stride_output.push_back(int(problem.k)); - stride_output.push_back(int(problem.q) * int(problem.k)); - stride_output.push_back(int(problem.q) * int(problem.p) * - int(problem.k)); - } - - switch (conv_kind) { - case library::ConvKind::kFprop: - configuration.stride_a = stride_activations; - configuration.stride_b = stride_filters; - configuration.stride_c = stride_output; - - break; - case library::ConvKind::kDgrad: - configuration.stride_a = stride_output; - configuration.stride_b = stride_filters; - configuration.stride_c = stride_activations; - - break; - case library::ConvKind::kWgrad: - configuration.stride_a = stride_output; - configuration.stride_b = stride_activations; - configuration.stride_c = stride_filters; - - break; - default: - throw std::runtime_error( - "Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - }; - -protected: - - // - // Data members - // - - /// CONV problem obtained from problem space - Conv2dProblem problem_; - - /// Device memory allocations - Conv2dWorkspace conv_workspace_; - - /// CUTLASS parallel reduction operation to follow this* conv2d operation - library::Operation const *reduction_op_; - -public: - // - // Methods - // - - /// Ctor - Conv2dOperationProfiler(Options const &options); - - /// Destructor - virtual ~Conv2dOperationProfiler(); - - Conv2dProblem const& problem() const { return problem_; } - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - /// Method to profile an initialized CUTLASS operation - virtual Status profile_cutlass_( - PerformanceResult &result, - Options const &options, - library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace); - - - /// Initialize reduction problem dimensions and library::Operation - bool initialize_reduction_configuration_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::ConvDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against host reference - bool verify_with_host_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against device reference - bool verify_with_device_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -#if CUTLASS_ENABLE_CUDNN - - /// Verifies CUTLASS against cudnn reference - bool verify_with_cudnn_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -#endif //#if CUTLASS_ENABLE_CUDNN - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h deleted file mode 100644 index ac4abdef238b00f216053419620a60dfccfd5316..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h +++ /dev/null @@ -1,449 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines profiling functionality for convolution - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/handle.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/singleton.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#include "reduction_operation_profiler.h" -#if CUTLASS_ENABLE_CUDNN -#include "cudnn_helpers.h" -#endif //#if CUTLASS_ENABLE_CUDNN -#include "debug.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class Conv3dOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct Conv3dProblem { - - int64_t n, d, h, w, c, z, p, q, k, t, r, s; - int64_t pad_d, pad_h, pad_w; - int64_t stride_d, stride_h, stride_w; - int64_t dilation_d, dilation_h, dilation_w; - - std::vector alpha; - std::vector beta; - - library::SplitKMode split_k_mode; - int64_t split_k_slices; - - library::ConvModeID conv_mode; - - library::Provider eq_gemm_provider; - - // convolution with parallel interleaved reduction - // convolution epilogue (alpha, beta) = (1.0, 0.0) - // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) - std::vector alpha_one; - std::vector beta_zero; - - // - // Methods - // - - /// Total number of bytes loaded - int64_t bytes(library::ConvDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::ConvDescription const &operation_desc) const; - - /// Infers output size from the input size, padding, stride, and dilation - void set_default_output_size() { - z = ((d + pad_d - t * dilation_d) / stride_d) + 1; - p = ((h + pad_h - r * dilation_h) / stride_h) + 1; - q = ((w + pad_w - s * dilation_w) / stride_w) + 1; - } - - // Returns equivalent gemm problem size for convolution - cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); - case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); - case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns extent for tensor A - std::vector extent_a(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; - case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; - case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns extent for tensor B - std::vector extent_b(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; - case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; - case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns extent for tensor C - std::vector extent_c(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; - case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; - case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns layout for equivalent gemm matrix A - library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm - case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm - case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns layout for equivalent gemm matrix B - library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm - case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm - case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns layout for equivalent gemm matrix C - library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - // Gemm operator assumes column-major output - case library::ConvKind::kFprop: - case library::ConvKind::kDgrad: - case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns leading dimension for equivalent gemm matrix A - int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); - case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); - case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns leading dimension for equivalent gemm matrix B - int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); - case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); - case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns leading dimension for equivalent gemm matrix C - int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { - - switch (conv_kind) { - case library::ConvKind::kFprop: - case library::ConvKind::kDgrad: - case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - }; - - /// Workspace used - struct Conv2dWorkspace { - - /// Conv device allocations - DeviceAllocation *A; - DeviceAllocation *B; - DeviceAllocation *C; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - /// Library configuration and arguments for convolution operator - library::Conv3dConfiguration configuration; - library::ConvArguments arguments; - - /// Number of copies of the problem workspace which are visited sequentially during - /// profiling to avoid camping in the last level cache. - int problem_count; - - /// Buffer used for the cutlass conv2d operations' host workspace - std::vector host_workspace; - - /// Buffer used for the cutlass operations' device workspace - DeviceAllocation device_workspace; - - /// Library configuration and arguments for reduction operator - library::ReductionConfiguration reduction_configuration; - library::ReductionArguments reduction_arguments; - - /// Buffer used for the cutlass reduction operations' host workspace - std::vector reduction_host_workspace; - - /// Host data buffers for host reference operation - /// host buffer for tensor - std::vector host_tensor_a; - - /// host buffer for tensor b - std::vector host_tensor_b; - - /// host buffer for tensor c - std::vector host_tensor_c; - - - // - // Methods - // - - Conv2dWorkspace(): - A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } - - // Returns stride vector for tensor A - std::vector stride_a(library::ConvKind const &conv_kind) { - return { - configuration.layout_a(conv_kind).stride()[0], - configuration.layout_a(conv_kind).stride()[1], - configuration.layout_a(conv_kind).stride()[2], - configuration.layout_a(conv_kind).stride()[3] - }; - } - - // Returns stride vector for tensor B - std::vector stride_b(library::ConvKind const &conv_kind) { - - return { - configuration.layout_b(conv_kind).stride()[0], - configuration.layout_b(conv_kind).stride()[1], - configuration.layout_b(conv_kind).stride()[2], - configuration.layout_b(conv_kind).stride()[3] - }; - } - - // Returns stride vector for tensor C - std::vector stride_c(library::ConvKind const &conv_kind) { - - return { - configuration.layout_c(conv_kind).stride()[0], - configuration.layout_c(conv_kind).stride()[1], - configuration.layout_c(conv_kind).stride()[2], - configuration.layout_c(conv_kind).stride()[3] - }; - } - }; - -protected: - - // - // Data members - // - - /// CONV problem obtained from problem space - Conv3dProblem problem_; - - /// Device memory allocations - Conv2dWorkspace conv_workspace_; - - /// CUTLASS parallel reduction operation to follow this* conv2d operation - library::Operation const *reduction_op_; - -public: - // - // Methods - // - - /// Ctor - Conv3dOperationProfiler(Options const &options); - - /// Destructor - virtual ~Conv3dOperationProfiler(); - - Conv3dProblem const& problem() const { return problem_; } - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Updates the arguments structure for the CUTLASS operator based on - /// the problem index. - void set_cutlass_operator_arguments_(int problem_idx = 0); - - /// Method to profile an initialized CUTLASS operation - virtual Status profile_cutlass_( - PerformanceResult &result, - Options const &options, - library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace); - - /// Initialize reduction problem dimensions and library::Operation - bool initialize_reduction_configuration_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::ConvDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against host reference - bool verify_with_host_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against device reference - bool verify_with_device_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -#if CUTLASS_ENABLE_CUDNN - - /// Verifies CUTLASS against cudnn reference - bool verify_with_cudnn_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -#endif //#if CUTLASS_ENABLE_CUDNN - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h deleted file mode 100644 index 873ba1abe03c05df29edc032ea3f1ffd2f19c3ee..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h +++ /dev/null @@ -1,456 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Helper functions for mapping CUTLASS concepts to cuBLAS. -*/ - -#pragma once - -#if CUTLASS_ENABLE_CUBLAS -#include -#include - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/blas3.h" - -#include "options.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Converts a cuBLAS status to cutlass::Status -Status get_cutlass_status(cublasStatus_t cublas); - -/// Converts a cuBLAS status to cutlass::profiler::Disposition -Disposition get_cutlass_disposition(cublasStatus_t cublas_status); - -/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -bool get_cublas_transpose_operation( - cublasOperation_t &operation, - library::LayoutTypeID layout, - library::ComplexTransform transform = library::ComplexTransform::kNone); - -/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration -bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); - -/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class -cublasGemmAlgo_t get_cublas_gemm_algo( - int cta_m, - int cta_n, - int cta_k, - library::OpcodeClassID opcode_class); - -/// Returns a status if cuBLAS can satisfy a particular GEMM description -Status cublas_satisfies(library::GemmDescription const &desc); - -/// Returns a status if cuBLAS can satisfy a particular RankK description -Status cublas_satisfies(library::RankKDescription const &desc); - -/// Returns a status if cuBLAS can satisfy a particular TRMM description -Status cublas_satisfies(library::TrmmDescription const &desc); - -/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description -Status cublas_satisfies(library::SymmDescription const &desc); - -/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and -/// to destroy cublasHandle_t on CublasCreate object destruction. -/// Additionally, it provides implicit cast from CublasCreate's object to cublasHandle_t's object -class CublasCreate { -private: - cublasHandle_t handle; - cublasStatus_t status; - -public: - CublasCreate() { - status = cublasCreate(&handle); - } - - ~CublasCreate() { - cublasDestroy(handle); - } - - /// Implicit cast CublasCreate object to cublasHandle_t - operator cublasHandle_t() const { return handle; } - - /// returns cublasStatus_t for handle creation - cublasStatus_t get_cublas_create_status() { return status; } -}; - -/// This is a helper class to create cublasLtHandle_t automatically on CublasLtCreate object creation and -/// to destroy cublasLtHandle_t on CublasLtCreate object destruction. -/// Additionally, it provides implicit cast from CublasLtCreate's object to cublasLtHandle_t's object -class CublasLtCreate { -private: - cublasLtHandle_t handle; - cublasStatus_t status; - -public: - CublasLtCreate() { - status = cublasLtCreate(&handle); - } - - ~CublasLtCreate() { - cublasLtDestroy(handle); - } - - /// Implicit cast CublasLtCreate object to cublasLtHandle_t - operator cublasLtHandle_t() const { return handle; } - - /// returns cublasLtStatus_t for handle creation - cublasStatus_t get_cublaslt_create_status() { return status; } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Selects one or more cuBLAS algorithms. -static void select_cublas_algorithms( - std::vector &algorithms, - Options const &options, - library::GemmDescription const &op_desc) { - - library::OpcodeClassID const & opcode_class = - op_desc.tile_description.math_instruction.opcode_class; - - switch (options.library.algorithm_mode) { - case AlgorithmMode::kMatching: - { - algorithms.push_back(get_cublas_gemm_algo( - op_desc.tile_description.threadblock_shape.m(), - op_desc.tile_description.threadblock_shape.n(), - op_desc.tile_description.threadblock_shape.k(), - opcode_class)); - break; - } - - case AlgorithmMode::kBest: - { - // Choose first enumerated mode. If none are enumerated, choose based on opcode class - // and evaluate all of them. - - if (options.library.algorithms.empty()) { - // Enumerate all algorithms - if (opcode_class == library::OpcodeClassID::kSimt) { - - for (int algo = CUBLAS_GEMM_DEFAULT; - algo <= CUBLAS_GEMM_ALGO23; - ++algo) { - - algorithms.push_back(cublasGemmAlgo_t(algo)); - } - } - else { - - for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; - ++algo) { - - algorithms.push_back(cublasGemmAlgo_t(algo)); - } - } - } - else { - // Use the listed algorithms - algorithms.reserve(options.library.algorithms.size()); - - for (int algo : options.library.algorithms) { - algorithms.push_back(reinterpret_cast(algo)); - } - } - - break; - } - - case AlgorithmMode::kDefault: - { - - // Use the library's default algorithm - algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? - CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - break; - } - default: - { - break; - } - } -} - -/// Dispatcher to cublasGemmEx() -struct cublasGemmExDispatcher { - - // - // Data members - // - library::GemmUniversalConfiguration configuration; - library::GemmUniversalArguments arguments; - - // cublas-specific data structures to fill cublas API call arguments - cublasOperation_t trans_A; - cublasOperation_t trans_B; - cudaDataType_t data_type_A; - cudaDataType_t data_type_B; - cudaDataType_t data_type_C; - cudaDataType_t compute_data_type; - -#if (__CUDACC_VER_MAJOR__ >= 11) - cublasComputeType_t compute_type; -#endif - - cublasGemmAlgo_t algo; - Status status; - - // - // Methods - // - - cublasGemmExDispatcher( - library::GemmDescription const &op_desc, - library::GemmUniversalConfiguration configuration_, - library::GemmUniversalArguments arguments_, - cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT - ); - - /// Executes GEMM using these arguments - cublasStatus_t operator()(cublasHandle_t handle); -}; - -/// Dispatcher to cublaslt kernels -// -struct cublasLtGemmExDispatcher { - - // - // Data members - // - library::GemmDescription const &op_desc; - library::GemmUniversalConfiguration configuration; - library::GemmUniversalArguments arguments; - - // cublas-specific data structures to fill cublas API call arguments - cublasOperation_t trans_A; - cublasOperation_t trans_B; - cudaDataType_t data_type_A; - cudaDataType_t data_type_B; - cudaDataType_t data_type_C; - cudaDataType_t compute_data_type = CUDA_R_32F; - - //cublasLt-specific data structures - cublasLtMatmulDesc_t operationDesc = NULL; - cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL; - cublasLtMatmulPreference_t preference = NULL; - - //is set by call to get_cublaslt_algo() - cublasLtMatmulHeuristicResult_t heuristicResult_; - void *workspace = nullptr; - - Status status; - -#if (__CUDACC_VER_MAJOR__ >= 11) - cublasComputeType_t compute_type; -#endif - - // - // Methods - // - - cublasLtGemmExDispatcher( - library::GemmDescription const &op_desc, - library::GemmUniversalConfiguration configuration_, - library::GemmUniversalArguments arguments_ - ); - - /// Initialize the cublasLt variables - void initialize_cublaslt(); - - - /// Runs auto-tuning for the cublas heuristics - bool get_cublaslt_algo(cublasLtHandle_t handle, - AlgorithmMode algorithm_mode - ); - - /// Executes GEMM using these arguments - cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); - - ~cublasLtGemmExDispatcher(){ - - // descriptors are no longer needed as all GPU work was already enqueued - if (preference) cublasLtMatmulPreferenceDestroy(preference); - if (Ddesc) cublasLtMatrixLayoutDestroy(Ddesc); - if (Cdesc) cublasLtMatrixLayoutDestroy(Cdesc); - if (Bdesc) cublasLtMatrixLayoutDestroy(Bdesc); - if (Adesc) cublasLtMatrixLayoutDestroy(Adesc); - if (operationDesc) cublasLtMatmulDescDestroy(operationDesc); - - if (workspace) { - cudaFree(workspace); - } - - } - -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dispatcher to cublas rank k update kernels -struct cublasRankKDispatcher { - - // - // Data members - // - library::RankKConfiguration configuration; - library::RankKArguments arguments; - - // cublas-specific data structures to fill cublas API call arguments - cublasOperation_t trans_A; - cublasFillMode_t uplo; - cudaDataType_t data_type_A; - cudaDataType_t data_type_C; - cudaDataType_t compute_data_type; - -#if (__CUDACC_VER_MAJOR__ >= 11) - cublasComputeType_t compute_type; -#endif - - int num_ranks; //(rank-k or rank-2k) - BlasMode blas_mode; //(symmetric or hermitian) - Status status; - - // - // Methods - // - - cublasRankKDispatcher( - library::RankKDescription const &op_desc, - library::RankKConfiguration configuration_, - library::RankKArguments arguments_ - ); - - /// Executes RankK using these arguments - cublasStatus_t operator()(cublasHandle_t handle); -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dispatcher to cublasTrmm() -struct cublasTrmmDispatcher { - - // - // Data members - // - library::TrmmConfiguration configuration; - library::TrmmArguments arguments; - - // cublas-specific data structures to fill cublas API call arguments - cublasOperation_t trans_A; - cublasSideMode_t side; - cublasFillMode_t uplo; - cublasDiagType_t diag; - cudaDataType_t data_type_A; - cudaDataType_t data_type_B; - cudaDataType_t data_type_D; - cudaDataType_t compute_data_type; - -#if (__CUDACC_VER_MAJOR__ >= 11) - cublasComputeType_t compute_type; -#endif - - Status status; - - // - // Methods - // - - cublasTrmmDispatcher( - library::TrmmDescription const &op_desc, - library::TrmmConfiguration configuration_, - library::TrmmArguments arguments_ - ); - - /// Executes TRMM using these arguments - cublasStatus_t operator()(cublasHandle_t handle); -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Dispatcher to cublas symm/hemm update kernels -struct cublasSymmDispatcher { - - // - // Data members - // - library::SymmConfiguration configuration; - library::SymmArguments arguments; - - // cublas-specific data structures to fill cublas API call arguments - cublasSideMode_t side; - cublasFillMode_t uplo; - cudaDataType_t data_type_A; - cudaDataType_t data_type_B; - cudaDataType_t data_type_C; - cudaDataType_t compute_data_type; - -#if (__CUDACC_VER_MAJOR__ >= 11) - cublasComputeType_t compute_type; -#endif - - BlasMode blas_mode; //(symmetric or hermitian) - Status status; - - // - // Methods - // - - cublasSymmDispatcher( - library::SymmDescription const &op_desc, - library::SymmConfiguration configuration_, - library::SymmArguments arguments_ - ); - - /// Executes Symm using these arguments - cublasStatus_t operator()(cublasHandle_t handle); -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -} // namespace profiler -} // namespace cutlass - - -#endif // #if CUTLASS_ENABLE_CUBLAS diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h deleted file mode 100644 index 7ce9eea5a883fa4c5732f5d8aec120a99064bac0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h +++ /dev/null @@ -1,590 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Helper functions for mapping CUTLASS concepts to cuDNN. - -*/ - -#pragma once -#if CUTLASS_ENABLE_CUDNN -#include -#include -#include -#include "cutlass/cutlass.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/library/library.h" -#include "enumerated_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Converts a cuDNN status to cutlass::Status -Status get_cutlass_status(cudnnStatus_t cudnn_status); - -/// Converts a cuDNN status to cutlass::profiler::Disposition -Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); - -/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception -Status checkCudnnErr(cudnnStatus_t cudnn_status); - -/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration -bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); - -/// Maps a CUTLASS layout type to a cuDNN data type enumeration -bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); - -/// Maps a CUTLASS numeric type to a cuDNN data type enumeration -bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); - -/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type -bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); - -/// Returns a status if cudnn can satisfy a particular Conv2d description -Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); - -/// Returns a status if cudnn can satisfy a particular Conv3d description -Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); - -/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) -float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); - - -/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and -/// to destroy cudnnHandle_t on CudnnCreate object destruction. -/// Additionally, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object -class CudnnCreate { -private: - cudnnHandle_t handle; - cudnnStatus_t status; - -public: - CudnnCreate() { - status = cudnnCreate(&handle); - } - - ~CudnnCreate() { - cudnnDestroy(handle); - } - - /// Implicit cast CudnnCreate object to cudnnHandle_t - operator cudnnHandle_t() const { return handle; } - - /// returns cudnnStatus_t for handle creation - cudnnStatus_t get_cudnn_create_status() { return status; } -}; - - -namespace detail { - -/// Dispatcher to cudnn convolution operators -struct cudnnConvDispatcher { - - // - // Data members - // - //library::Conv2dConfiguration configuration; - library::ConvArguments arguments; - library::ConvKind conv_kind; - - // cudnn-specific data structures to fill cudnn API call arguments - // cudnn activation, filter, and output descriptors - cudnnTensorDescriptor_t activation_desc; - cudnnFilterDescriptor_t filter_desc; - cudnnTensorDescriptor_t output_desc; - cudnnConvolutionDescriptor_t conv_desc; - - // cudnn datatypes - cudnnDataType_t data_type_activation; - cudnnDataType_t data_type_filter; - cudnnDataType_t data_type_output; - - // cudnn layouts - cudnnTensorFormat_t layout_activation; - cudnnTensorFormat_t layout_filter; - cudnnTensorFormat_t layout_output; - - // cudnn convolution mode - cudnnConvolutionMode_t conv_mode; - - // cudnn math type (tensorop, tensorop with conversion, simt) - cudnnMathType_t math_type; - - // cudnn compute data type - cudnnDataType_t compute_type; - - // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) - float alpha; - float beta; - - // cudnn workspace - size_t workspace_size_in_bytes = 0; - cutlass::device_memory::allocation workspace; - - // select cudnn's implicit gemm precomputed algorithm with tensor operations - static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - - Status status; - - // - // Methods - // - - // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfiguration - - // ctor for conv2d - cudnnConvDispatcher( - library::ConvDescription const &op_desc, - library::Conv2dConfiguration configuration, - library::ConvArguments arguments_, - cudnnHandle_t handle - ): - //configuration(configuration_), - arguments(arguments_), - conv_kind(op_desc.conv_kind), - status(Status::kSuccess) { - - bool good = true; - - // Get cudnn datatype, layout, and convolution mode from library::ConvDescription - good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); - good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); - good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); - good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); - good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); - good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); - good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); - // Get cudnn mathtype (cudnnMathType_t) - good = (good && get_cudnn_mathtype(math_type, op_desc)); - good = (good && get_cudnn_datatype( - compute_type, - op_desc.tile_description.math_instruction.element_accumulator)); - // Check cutlass Conv2d description has equivalent operator in cudnn - if (!good) { - status = Status::kErrorNotSupported; - return; - } - // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) - alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); - beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); - - // Create convolution descriptor object - status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); - - // Configure convolution operator - std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; - std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; - std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; - - status = get_cutlass_status( - cudnnSetConvolutionNdDescriptor( - conv_desc, - op_desc.conv_dim, - padding.data(), - stride.data(), - dilation.data(), - conv_mode, - compute_type - )); - - // Set groups - status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); - - // Create activation, filter, and output descriptor objects - status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); - status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); - status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); - - // Set activation, filter, and output descriptor - status = get_cutlass_status( - cudnnSetTensor4dDescriptor( - activation_desc, - layout_activation, - data_type_activation, - configuration.problem_size.N, - configuration.problem_size.C, - configuration.problem_size.H, - configuration.problem_size.W - )); - - status = get_cutlass_status( - cudnnSetFilter4dDescriptor( - filter_desc, - data_type_filter, - layout_filter, - configuration.problem_size.K, - configuration.problem_size.C / configuration.problem_size.groups, - configuration.problem_size.R, - configuration.problem_size.S - )); - - status = get_cutlass_status( - cudnnSetTensor4dDescriptor( - output_desc, - layout_output, - data_type_output, - configuration.problem_size.N, - configuration.problem_size.K, - configuration.problem_size.P, - configuration.problem_size.Q - )); - - // Set math instruction to tensor op - status = get_cutlass_status( - cudnnSetConvolutionMathType(conv_desc, math_type)); - - // Initialize workspace - switch (conv_kind) { - case library::ConvKind::kFprop: - status = get_cutlass_status( - cudnnGetConvolutionForwardWorkspaceSize( - handle, - activation_desc, - filter_desc, - conv_desc, - output_desc, - fprop_algo, - &workspace_size_in_bytes - )); break; - case library::ConvKind::kDgrad: - status = get_cutlass_status( - cudnnGetConvolutionBackwardDataWorkspaceSize( - handle, - filter_desc, - output_desc, - conv_desc, - activation_desc, - dgrad_algo, - &workspace_size_in_bytes - )); break; - case library::ConvKind::kWgrad: - status = get_cutlass_status( - cudnnGetConvolutionBackwardFilterWorkspaceSize( - handle, - activation_desc, - output_desc, - conv_desc, - filter_desc, - wgrad_algo, - &workspace_size_in_bytes - )); break; - - } - - workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); - } - - - // ctor for conv3d - cudnnConvDispatcher( - library::ConvDescription const &op_desc, - library::Conv3dConfiguration configuration, - library::ConvArguments arguments_, - cudnnHandle_t handle - ): - //configuration(configuration_), - arguments(arguments_), - conv_kind(op_desc.conv_kind), - status(Status::kSuccess) { - - bool good = true; - - // Get cudnn datatype, layout, and convolution mode from library::ConvDescription - good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); - good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); - good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); - - good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); - good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); - good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); - - good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); - - // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) - alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); - beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); - - good = (good && get_cudnn_datatype( - compute_type, - op_desc.tile_description.math_instruction.element_accumulator)); - - // Check cutlass Conv2d description has equivalent operator in cudnn - if (!good) { - status = Status::kErrorNotSupported; - } - - // Create convolution descriptor object - status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); - - // Configure convolution operator - std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; - std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; - std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; - - status = get_cutlass_status( - cudnnSetConvolutionNdDescriptor( - conv_desc, - op_desc.conv_dim, - padding.data(), - stride.data(), - dilation.data(), - conv_mode, - compute_type - )); - - // Set groups - status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); - - // Create activation, filter, and output descriptor objects - status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); - status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); - status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); - - // Set activation descriptor - std::vector activation_extent { - configuration.problem_size.N, - configuration.problem_size.C, - configuration.problem_size.D, - configuration.problem_size.H, - configuration.problem_size.W - }; - - std::vector activation_stride { - configuration.layout_activations.stride()[3], - 1, - configuration.layout_activations.stride()[2], - configuration.layout_activations.stride()[1], - configuration.layout_activations.stride()[0] - }; - - status = get_cutlass_status( - cudnnSetTensorNdDescriptor( - activation_desc, - data_type_activation, - op_desc.conv_dim + 2, - activation_extent.data(), - activation_stride.data() - )); - - // Set filter descriptor - std::vector filter_extent { - configuration.problem_size.K, - configuration.problem_size.C, - configuration.problem_size.T, - configuration.problem_size.R, - configuration.problem_size.S - }; - - std::vector filter_stride { - configuration.layout_filters.stride()[3], - 1, - configuration.layout_filters.stride()[2], - configuration.layout_filters.stride()[1], - configuration.layout_filters.stride()[0] - }; - - status = get_cutlass_status( - cudnnSetFilterNdDescriptor( - filter_desc, - data_type_filter, - layout_filter, - op_desc.conv_dim + 2, - filter_extent.data() - )); - - - // Set output descriptor - std::vector output_extent { - configuration.problem_size.N, - configuration.problem_size.K, - configuration.problem_size.Z, - configuration.problem_size.P, - configuration.problem_size.Q - }; - - std::vector output_stride { - configuration.layout_output.stride()[3], - 1, - configuration.layout_output.stride()[2], - configuration.layout_output.stride()[1], - configuration.layout_output.stride()[0] - }; - - status = get_cutlass_status( - cudnnSetTensorNdDescriptor( - output_desc, - data_type_output, - op_desc.conv_dim + 2, - output_extent.data(), - output_stride.data() - )); - - // Set math instruction to tensor op - status = get_cutlass_status( - cudnnSetConvolutionMathType(conv_desc, math_type)); - - // Initialize workspace - switch (conv_kind) { - case library::ConvKind::kFprop: - status = get_cutlass_status( - cudnnGetConvolutionForwardWorkspaceSize( - handle, - activation_desc, - filter_desc, - conv_desc, - output_desc, - fprop_algo, - &workspace_size_in_bytes - )); break; - case library::ConvKind::kDgrad: - status = get_cutlass_status( - cudnnGetConvolutionBackwardDataWorkspaceSize( - handle, - filter_desc, - output_desc, - conv_desc, - activation_desc, - dgrad_algo, - &workspace_size_in_bytes - )); break; - case library::ConvKind::kWgrad: - status = get_cutlass_status( - cudnnGetConvolutionBackwardFilterWorkspaceSize( - handle, - activation_desc, - output_desc, - conv_desc, - filter_desc, - wgrad_algo, - &workspace_size_in_bytes - )); break; - - } - - workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); - } - - /// Executes Conv2d operator from cudnn library - cudnnStatus_t operator()(cudnnHandle_t handle) { - - switch (conv_kind) { - case library::ConvKind::kFprop: - return cudnnConvolutionForward( - handle, - &alpha, - activation_desc, - activation(), - filter_desc, - filter(), - conv_desc, - fprop_algo, - workspace.get(), - workspace_size_in_bytes, - &beta, - output_desc, - arguments.D - ); - case library::ConvKind::kDgrad: - return cudnnConvolutionBackwardData( - handle, - &alpha, - filter_desc, - filter(), - output_desc, - output(), - conv_desc, - dgrad_algo, - workspace.get(), - workspace_size_in_bytes, - &beta, - activation_desc, - arguments.D - ); - case library::ConvKind::kWgrad: - return cudnnConvolutionBackwardFilter( - handle, - &alpha, - activation_desc, - activation(), - output_desc, - output(), - conv_desc, - wgrad_algo, - workspace.get(), - workspace_size_in_bytes, - &beta, - filter_desc, - arguments.D - ); - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Activation Tensor - void const * activation() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return arguments.A; - case library::ConvKind::kDgrad : return arguments.C; - case library::ConvKind::kWgrad : return arguments.B; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Filter Tensor - void const *filter() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return arguments.B; - case library::ConvKind::kDgrad : return arguments.B; - case library::ConvKind::kWgrad : return arguments.C; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Output Tensor - void const *output() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return arguments.C; - case library::ConvKind::kDgrad : return arguments.A; - case library::ConvKind::kWgrad : return arguments.A; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } -}; - -} // namespace detail -///////////////////////////////////////////////////////////////////////////////////////////////// -#endif //#if CUTLASS_ENABLE_CUDNN -} // namespace profiler -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h deleted file mode 100644 index be82245325cebb147e2c801965a52ece91395cb2..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h +++ /dev/null @@ -1,93 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Execution environment -*/ - -#pragma once -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" -#include "cutlass/library/singleton.h" - -#include "options.h" -#include "operation_profiler.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// CUTLASS Profiler application -class CutlassProfiler { -private: - - // - // Data members - // - - /// Performance testbench options - Options options_; - - /// Entry points for each operation - OperationProfilerVector operation_profilers_; - -private: - - /// Prints usage - void print_usage_(std::ostream &); - - /// Prints usage - void print_options_(std::ostream &); - - /// Enumerates all operations - void enumerate_(); - - /// Profiles all operations - int profile_(); - -public: - - CutlassProfiler(Options const &options); - ~CutlassProfiler(); - - /// Invokes profiling operations - int operator()(); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h deleted file mode 100644 index 98f1fdc3044501e456c927471b30d74b09eafd39..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h +++ /dev/null @@ -1,56 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief -*/ - -#pragma once - -#include - -//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } -//#define report(x) {} - -// Enable/Disable Profiler debug prints -//#define DEBUG_PROFILER - -//RED 31m // profiler prints debug messages in red -//YELLOW 33m // ir prints debug messages in yellow - -#ifndef DEBUG_PROFILER -#define debugprof(...) -#else -#define debugprof(...) do { \ - printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ - printf(__VA_ARGS__); \ - printf("\033[0m\n"); \ - } while (0) -#endif diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h deleted file mode 100644 index 488b635c2ec233e3027303bbf15a34f375a438fd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h +++ /dev/null @@ -1,246 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Execution environment -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/library/library.h" -#include "cutlass/util/distribution.h" - -#include "enumerated_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Device memory allocation -class DeviceAllocation { -private: - - /// Data type of contained elements - library::NumericTypeID type_; - - /// Gets the stride between elements - size_t batch_stride_; - - /// Capacity in elements of device allocation - size_t capacity_; - - /// Pointer to device memory - void *pointer_; - - /// Layout type ID - library::LayoutTypeID layout_; - - /// Stride vector - std::vector stride_; - - /// Extent vector - std::vector extent_; - - /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory - int batch_count_; - - /// Buffer holding TensorRef instance to recently allocated memory - std::vector tensor_ref_buffer_; - - /// The device ID where the allocation is made - int device_; - -public: - // - // Static member functions - // - - /// Determines the number of bytes needed to represent this numeric type - static size_t bytes(library::NumericTypeID type, size_t capacity); - - /// Returns the stride of a packed layout - static std::vector get_packed_layout( - library::LayoutTypeID layout_id, - std::vector const &extent); - - /// returns the capacity needed - static size_t construct_layout( - void *bytes, - library::LayoutTypeID layout_id, - std::vector const &extent, - std::vector &stride); - - /// Returns true if two blocks have exactly the same value - static bool block_compare_equal( - library::NumericTypeID numeric_type, - void const *ptr_A, - void const *ptr_B, - size_t capacity); - - /// Returns true if two blocks have approximately the same value - static bool block_compare_relatively_equal( - library::NumericTypeID numeric_type, - void const *ptr_A, - void const *ptr_B, - size_t capacity, - double epsilon, - double nonzero_floor); - -public: - // - // Methods - // - - DeviceAllocation(); - - DeviceAllocation( - library::NumericTypeID type, - size_t capacity, - int device = -1); - - DeviceAllocation( - library::NumericTypeID type, - library::LayoutTypeID layout_id, - std::vector const &extent, - std::vector const &stride = std::vector(), - int batch_count = 1, - int device = -1); - - ~DeviceAllocation(); - - DeviceAllocation &reset(); - - /// Allocates device memory of a given type and capacity - DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); - - /// Allocates memory for a given layout and tensor - DeviceAllocation &reset( - library::NumericTypeID type, - library::LayoutTypeID layout_id, - std::vector const &extent, - std::vector const &stride = std::vector(), - int batch_count = 1); - - /// Returns a buffer owning the tensor reference - std::vector &tensor_ref() { - return tensor_ref_buffer_; - } - - bool good() const; - - /// Data type of contained elements - library::NumericTypeID type() const; - - /// Pointer to start of device memory allocation - void *data() const; - - /// Pointer to the first element of a batch - void *batch_data(int batch_idx) const; - - /// Gets the layout type - library::LayoutTypeID layout() const; - - /// Gets the stride vector - std::vector const & stride() const; - - /// Gets the extent vector - std::vector const & extent() const; - - /// Gets the number of adjacent tensors in memory - int batch_count() const; - - /// Gets the stride (in units of elements) between items - int64_t batch_stride() const; - - /// Gets the stride (in units of bytes) between items - int64_t batch_stride_bytes() const; - - /// Capacity of allocation in number of elements - size_t capacity() const; - - /// Capacity of allocation in bytes - size_t bytes() const; - - /// Initializes a device allocation to a random distribution using cuRAND - void initialize_random_device(int seed, Distribution dist); - - /// Initializes a host allocation to a random distribution using std::cout - void initialize_random_host(int seed, Distribution dist); - - /// Initializes a device allocation to a sequential distribution - void initialize_sequential_device(Distribution dist); - - /// Initializes a host allocation to a sequential distribution - void initialize_sequential_host(Distribution dist); - - /// Initializes a device allocation to a random distribution using cuRAND - void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); - - /// Initializes a host allocation to a random distribution using std::cout - void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); - - /// Uniformly fills a tensor with a value when provided o.w. zero - void fill_device(double value); - - /// Uniformly fills a host allocation with a value when provided o.w. zero - void fill_host(double value); - - /// Copies from an equivalent-sized tensor in device memory - void copy_from_device(void const *ptr); - - /// Copies from an equivalent-sized tensor in device memory - void copy_from_host(void const *ptr); - - /// Copies from an equivalent-sized tensor in device memory - void copy_to_host(void *ptr); - - /// Writes a tensor to csv - void write_tensor_csv(std::ostream &out); - -private: - /// A wrapper that sets the device, performs malloc, and sets back - cudaError_t malloc(void** ptr, size_t size); -}; - -using DeviceAllocationList = std::list; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h deleted file mode 100644 index 0443b340397426bfafc812c1a4b9179fc6af0de4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h +++ /dev/null @@ -1,136 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief -*/ - -#pragma once - -#include -#include - - -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" - -#include "options.h" -#include "device_allocation.h" - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Collection of allocations on the device -class DeviceContext { -public: - - // - // Type definitions - // - using AllocationMap = std::map; - -private: - // - // Data members - // - - /// Memory allocations that exist (owning) - DeviceAllocationList device_memory_; - - /// Non-owning set of named allocations - AllocationMap allocations_; - -public: - - /// Allocates memory of a given type, capacity (elements), and name - DeviceAllocation *allocate_block( - Options const &options, - std::string const &name, - library::NumericTypeID type, - size_t capacity, - size_t device_index); - - /// Allocates memory of a given type, capacity (elements), and name - DeviceAllocation *allocate_tensor( - Options const &options, - std::string const &name, - library::NumericTypeID type, - library::LayoutTypeID layout_id, - std::vector const &extent, - std::vector const &stride, - int batch_count, - size_t device_index); - - /// Allocates memory of a given type, capacity (elements), and name - DeviceAllocation *allocate_and_initialize_tensor( - Options const &options, - std::string const &name, - library::NumericTypeID type, - library::LayoutTypeID layout_id, - std::vector const &extent, - std::vector const &stride, - int batch_count, - int seed_shift, - size_t device_index); - - /// Allocates memory for sparse meta data - DeviceAllocation *allocate_and_initialize_sparsemeta_tensor( - Options const &options, - std::string const &name, - library::NumericTypeID type, - library::LayoutTypeID layout_id, - library::NumericTypeID type_a, - std::vector const &extent, - std::vector const &stride, - int batch_count, - int seed_shift, - size_t device_index); - - /// Clears named allocations (but does not necessarily free memory) - void clear(); - - /// Frees all device memory allocations - void free(); - - /// Gets the allocation by name - DeviceAllocation &at(std::string const &name); - - size_t size() const; - - AllocationMap::iterator begin(); - AllocationMap::iterator end(); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h deleted file mode 100644 index 897311c228ce76c4e8814ce996929561d44d2465..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h +++ /dev/null @@ -1,169 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Provides several functions for filling tensors with data. -*/ - -#pragma once - -#include -#include -#include -#include -#include "cutlass/library/library.h" - -#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -T from_string(std::string const &); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Enumerated type describing how the performance testbench evaluates kernels. -enum class ExecutionMode { - kProfile, ///< regular verification and profiling - kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched - kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations - kTrace, ///< executes a single device-side computation with no other kernel launches - kInvalid -}; - -/// Converts a ExecutionMode enumerant to a string -char const *to_string(ExecutionMode mode, bool pretty = false); - -/// Parses a ExecutionMode enumerant from a string -template <> -ExecutionMode from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Library algorithm mode -enum class AlgorithmMode { - kMatching, ///< compare against best matching algorithm - kBest, ///< evaluate all library algorithms and report best - kDefault, ///< use the library's default algorithm option - kInvalid -}; - -/// Converts a ExecutionMode enumerant to a string -char const *to_string(AlgorithmMode mode, bool pretty = false); - -/// Parses a ExecutionMode enumerant from a string -template <> -AlgorithmMode from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Outcome of a performance test -enum class Disposition { - kPassed, - kFailed, // kernel itself reported an error - kNotRun, - kIncorrect, // kernel finished without a detected error, but result does not equal expected result - kNotVerified, - kInvalidProblem, - kNotSupported, - kInvalid -}; - -/// Converts a Disposition enumerant to a string -char const *to_string(Disposition disposition, bool pretty = false); - -/// Parses a Disposition enumerant from a string -template <> -Disposition from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Indicates when to save -enum class SaveWorkspace { - kNever, - kIncorrect, - kAlways, - kInvalid -}; - -/// Converts a SaveWorkspace enumerant to a string -char const *to_string(SaveWorkspace save_option, bool pretty = false); - -/// Parses a SaveWorkspace enumerant from a string -template <> -SaveWorkspace from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Indicates the type of kernel argument -// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric -// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. -// Its c++ equivalent as "type name = initializer" is "u32 m = 32" -// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. -// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" -enum class ArgumentTypeID { - kScalar, - kInteger, - kTensor, - kBatchedTensor, - kStructure, - kEnumerated, - kInvalid -}; - -/// Converts a ArgumentTypeID enumerant to a string -char const *to_string(ArgumentTypeID type, bool pretty = false); - -/// Parses a ArgumentTypeID enumerant from a string -template <> -ArgumentTypeID from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// -// Profiler typedefs -using ProviderVector = std::vector; -using DispositionMap = std::map; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Print vector for the report -template -std::ostream& operator<< (std::ostream& out, const std::vector& v) { - for (size_t i = 0; i < v.size(); ++i) { - out << to_string(v[i], true) << (i + 1u != v.size() ? "," : ""); - } - return out; -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h deleted file mode 100644 index faf317152473cac6dc62ecf8970cd1acfb2c1622..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ /dev/null @@ -1,333 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Gemm Profiler -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#include "reduction_operation_profiler.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class GemmOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct GemmProblem { - - cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; - - /// For profiling purposes - std::vector problem_sizes; - std::vector> leading_dims; - std::vector> preferred_clusters; - std::vector> fallback_clusters; - std::vector raster_orders; - std::vector swizzle_sizes; - - int64_t m{16}; - int64_t n{16}; - int64_t k{16}; - - - int cluster_m{1}; - int cluster_n{1}; - int cluster_k{1}; - int cluster_m_fallback{1}; - int cluster_n_fallback{1}; - int cluster_k_fallback{1}; - - - int64_t lda{0}; - int64_t ldb{0}; - int64_t ldc{0}; - std::vector alpha; - std::vector beta; - - cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; - int split_k_slices{1}; - int batch_count{1}; - - cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; - int swizzle_size{1}; - cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; - cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; - - - // gemm with parallel interleaved reduction - // gemm epilogue (alpha, beta) = (1.0, 0.0) - // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) - std::vector alpha_one; - std::vector beta_zero; - - bool use_pdl{false}; - - bool enable_sm90_mixed_dtype_shuffle_test{false}; - - // - // Methods - // - - /// Parses the problem - Status parse( - library::GemmDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - int64_t bytes_with_problem_shape( - library::GemmDescription const &operation_desc, - gemm::GemmCoord const &problem_shape) const; - - int64_t flops_with_problem_shape( - library::GemmDescription const &operation_desc, - gemm::GemmCoord const &problem_shape) const; - - /// Total number of bytes loaded - int64_t bytes(library::GemmDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::GemmDescription const &operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::GemmDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct GemmWorkspace { - - DeviceAllocation *A{nullptr}; - DeviceAllocation *B{nullptr}; - DeviceAllocation *C{nullptr}; - DeviceAllocation *Computed{nullptr}; - DeviceAllocation *Reference{nullptr}; - - /// Number of copies of the problem workspace which are visited sequentially during - /// profiling to avoid camping in the last level cache. - int problem_count{1}; - - library::GemmUniversalConfiguration configuration; - library::GemmUniversalArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - /// Library configuration and arguments for reduction operator - library::ReductionConfiguration reduction_configuration; - library::ReductionArguments reduction_arguments; - - /// Buffer used for the cutlass reduction operations' host workspace - std::vector reduction_host_workspace; - - /// For mixed input dtype kernels - DeviceAllocation *Scale{nullptr}; // Scale tensor - DeviceAllocation *Zero{nullptr}; // Zero tensor - DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification - DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle - DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8 - - cudaStream_t stream; - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - GemmProblem problem_; - - /// Device memory allocations - std::vector gemm_workspace_; - - /// CUTLASS parallel reduction operation to follow this* gemm operation - library::Operation const *reduction_op_; - -public: - // - // Methods - // - - /// Ctor - GemmOperationProfiler(Options const &options); - - /// Destructor - virtual ~GemmOperationProfiler(); - - GemmProblem const& problem() const { return problem_; } - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - /// Update workspace configuration according to flexible user setups - void update_workspace_( - GemmWorkspace &gemm_workspace, - gemm::GemmCoord const &problem_shape, - std::array const &leading_dim, - std::array const &preferred_cluster, - std::array const &fallback_cluster, - cutlass::library::RasterOrder const &raster_order, - int swizzle_size, - bool is_dynamic_cluster_enabled); - - /// Update performance result configuration according to flexible user setups - void update_result_( - PerformanceResult &result, - library::GemmDescription const &operation_desc, - ProblemSpace const &problem_space, - gemm::GemmCoord const &problem_shape, - cutlass::library::RasterOrder const &raster_order, - std::array const &preferred_cluster, - std::array const &fallback_cluster, - int swizzle_size, - bool is_dynamic_cluster_enabled); - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::GemmDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem, - GemmWorkspace &gemm_workspace); - - /// Verifies CUTLASS against host and device references - bool verify_with_reference_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem, - cutlass::library::NumericTypeID element_A, - cutlass::library::NumericTypeID element_B); - - /// Method to profile a CUTLASS Operation - Status profile_cutlass_( - PerformanceResult &result, - Options const &options, - library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace); - - /// Initialize reduction problem dimensions and library::Operation - bool initialize_reduction_configuration_( - library::Operation const *operation, - ProblemSpace::Problem const &problem); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h deleted file mode 100644 index 154045295d6443d930ba53387366f4b8abe408a4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h +++ /dev/null @@ -1,77 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function -*/ - -#pragma once - -#include -#include "cutlass/cutlass.h" - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct GpuTimer { - - cudaEvent_t events[2]; - - // - // Methods - // - - GpuTimer(); - - GpuTimer(GpuTimer const&) = delete; - - GpuTimer(GpuTimer &&gpu_timer) noexcept; - - ~GpuTimer(); - - /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags - void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - - /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags - void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - - /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags - void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - - /// Returns the duration in milliseconds - double duration(int iterations = 1) const; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h deleted file mode 100644 index 62d47990584cbb984935a00a267cff15dbb4f4e5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h +++ /dev/null @@ -1,344 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/* \file - \brief GroupedGemm Profiler -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" - -// Profiler includes -#include "device_context.h" -#include "operation_profiler.h" -#include "options.h" -#include "performance_result.h" -#include "problem_space.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class GroupedGemmOperationProfiler : public OperationProfiler { -public: - /// Problem structure obtained from problem space - struct GroupedGemmProblem { - - cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped}; - - std::vector problem_sizes; - std::vector> problem_sizes_3x; - - /// For exploration purposes - std::vector> preferred_clusters; - std::vector> fallback_clusters; - std::vector raster_orders; - std::vector swizzle_sizes; - - int cluster_m{1}; - int cluster_n{1}; - int cluster_k{1}; - int cluster_m_fallback{1}; - int cluster_n_fallback{1}; - int cluster_k_fallback{1}; - - std::vector lda{0}; - std::vector ldb{0}; - std::vector ldc{0}; - - std::vector alpha; - std::vector beta; - - cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; - int swizzle_size{1}; - - cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; - cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; - - bool use_pdl{false}; - - /// Parses the problem - Status parse( - library::GroupedGemmDescription const& operation_desc, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem); - - int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); }; - int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); }; - int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); }; - - /// Total number of bytes loaded - int64_t bytes(library::GroupedGemmDescription const& operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::GroupedGemmDescription const& operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult& result, - library::GroupedGemmDescription const& operation_desc, - ProblemSpace const& problem_space); - }; - - struct BlockScalingWorkspace { - // host vector (per L2 workspace) of device vectors (per group) of device pointers - std::vector SFA_ptr_array_device; - std::vector SFB_ptr_array_device; - std::vector SFC_ptr_array_device; - std::vector SFD_ptr_array_device; - - // host vector (per group) of device tensors - // (where each batch of device allocation is for a L2 workspace) - std::vector SFA_ptr_array_host; - std::vector SFB_ptr_array_host; - std::vector SFC_ptr_array_host; - std::vector SFD_ptr_array_host; - std::vector SFD_reference_ptr_array_host; - - // matrix wide constant, not per-batch or per-group - DeviceAllocation* norm_constant; - }; - - // workspace contains the allocated blocks, arguments just contain the raw - // pointers - struct GroupedGemmWorkspace { - - // host vector (per L2 workspace) of device vectors (per group) of device pointers - std::vector A_ptr_array_device; - std::vector B_ptr_array_device; - std::vector C_ptr_array_device; - std::vector D_ptr_array_device; - std::vector reference_ptr_array_host; - - // host vector (per group) of device tensors - // (where each batch of device allocation is for a L2 workspace) - std::vector A_ptr_array_host; - std::vector B_ptr_array_host; - std::vector C_ptr_array_host; - std::vector D_ptr_array_host; - - /// Number of copies of the problem workspace which are visited sequentially during - /// profiling to avoid camping in the last level cache. - /// *NOT* the number of groups in the grouped GEMM (we use `num_groups` in the profiler) - int problem_count{1}; - - DeviceAllocation* problem_sizes_array_device{nullptr}; - DeviceAllocation* problem_sizes_3x_array_device{nullptr}; - DeviceAllocation* lda_array_device{nullptr}; - DeviceAllocation* ldb_array_device{nullptr}; - DeviceAllocation* ldc_array_device{nullptr}; - DeviceAllocation* ldd_array_device{nullptr}; - - std::optional block_scales; - - library::GemmGroupedConfiguration configuration; - library::GroupedGemmBlockScaledArguments arguments; - - std::vector host_workspace; - DeviceAllocation device_workspace; - - cudaStream_t stream; - }; - -private: - void init_arguments(Options const& options) { - auto& arguments = gemm_workspace_.arguments; - // these get updated in each profiler run to ensure L2 cycling - arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data(); - arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data(); - arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data(); - arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data(); - - arguments.alpha = problem_.alpha.data(); - arguments.beta = problem_.beta.data(); - arguments.pointer_mode = library::ScalarPointerMode::kHost; - arguments.lda = static_cast(gemm_workspace_.lda_array_device->data()); - arguments.ldb = static_cast(gemm_workspace_.ldb_array_device->data()); - arguments.ldc = static_cast(gemm_workspace_.ldc_array_device->data()); - arguments.ldd = static_cast(gemm_workspace_.ldc_array_device->data()); - arguments.problem_sizes = - static_cast(gemm_workspace_.problem_sizes_array_device->data()); - arguments.problem_sizes_3x = static_cast*>( - gemm_workspace_.problem_sizes_3x_array_device->data()); - gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); - gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size(); - gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; - gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; - - /* Query device SM count to pass onto the kernel as an argument, where needed */ - arguments.sm_count = options.device.get_sm_count(0); - if (is_block_scaled) { - auto& block_scaled_ws = gemm_workspace_.block_scales.value(); - arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); - arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); - arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data(); - arguments.norm_constant = block_scaled_ws.norm_constant->data(); - } - else if (is_blockwise) { - auto& block_scaled_ws = gemm_workspace_.block_scales.value(); - arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); - arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); - } - } - -protected: - /// GEMM problem obtained from problem space - GroupedGemmProblem problem_; - - /// Device memory allocations - GroupedGemmWorkspace gemm_workspace_; - - bool is_block_scaled{false}; - bool is_blockwise{false}; - -public: - GroupedGemmOperationProfiler(Options const& options); - - virtual ~GroupedGemmOperationProfiler(); - - GroupedGemmProblem const& problem() const { return problem_; } - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream& out) const; - - /// Prints examples - virtual void print_examples(std::ostream& out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const& options, - PerformanceReport& report, - DeviceContext& device_context, - library::Operation const* operation, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const& options, - PerformanceReport& report, - DeviceContext& device_context, - library::Operation const* operation, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const& options, - PerformanceReport& report, - DeviceContext& device_context, - library::Operation const* operation, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem); - - /// Measures performance results - virtual bool profile( - Options const& options, - PerformanceReport& report, - DeviceContext& device_context, - library::Operation const* operation, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem); - -protected: - /// Initializes the performance result - void initialize_result_( - PerformanceResult& result, - Options const& options, - library::GroupedGemmDescription const& operation_desc, - ProblemSpace const& problem_space); - - /// Update workspace configuration according to flexible user setups - void update_workspace_( - GroupedGemmWorkspace &gemm_workspace, - std::array const &preferred_cluster, - std::array const &fallback_cluster, - cutlass::library::RasterOrder const &raster_order, - int swizzle_size, - bool is_dynamic_cluster_enabled); - - /// Update performance result configuration for exploration parameters - void update_workspace_and_result_( - GroupedGemmWorkspace &gemm_workspace, - PerformanceResult &result, - ProblemSpace const &problem_space, - cutlass::library::RasterOrder const &raster_order, - std::array const &preferred_cluster, - std::array const &fallback_cluster, - int swizzle_size, - bool is_dynamic_cluster_enabled); - - /// Verifies CUTLASS against host and device references - bool verify_with_reference_( - Options const& options, - PerformanceReport& report, - DeviceContext& device_context, - library::Operation const* operation, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem, - cutlass::library::NumericTypeID element_A, - cutlass::library::NumericTypeID element_B); - - /// Method to profile a CUTLASS Operation - Status profile_cutlass_( - PerformanceResult& result, - Options const& options, - library::Operation const* operation, - void* arguments, - void* host_workspace, - void* device_workspace) override; - - /// Method to profile a CUTLASS Operation for the best configuration for a fixed shape - bool profile_cutlass_for_fixed_shape_( - Options const& options, - library::Operation const* operation, - ProblemSpace const& problem_space); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h deleted file mode 100644 index 446ef2c16739b28aaf038ca62bad6e3cdf667813..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h +++ /dev/null @@ -1,287 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function -*/ - -#pragma once - -#include -#include -#include -#include - -// CUTLASS includes -#include "cutlass/trace.h" - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "performance_result.h" -#include "performance_report.h" -#include "problem_space.h" -#include "debug.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class OperationProfiler { -public: - - -protected: - // - // Data members - // - - /// Top-level operation kind - library::OperationKind kind_; - - /// Human readable description - std::string description_; - - /// Arguments parsed from command line - ArgumentDescriptionVector arguments_; - - /// List of providers used to verify and compare each result - ProviderVector verification_providers_; - - /// Model performance result initialized by the operation profiler with workload statistics - /// and reasonable default state. - PerformanceResult model_result_; - - /// Performance result vector constructed by profiling the operation - PerformanceResultVector results_; - -public: - - // - // Methods - // - - /// Ctor - OperationProfiler(); - - OperationProfiler( - Options const &options, - library::OperationKind kind, - ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), - ProviderVector const & verification_providers = ProviderVector()); - - /// Destructor - virtual ~OperationProfiler(); - - /// Obtains the operation kind - library::OperationKind kind() const { return kind_; } - - /// Gets the schema description - std::string const &description() const; - - /// Returns a reference to the arguments - ArgumentDescriptionVector const &arguments() const { return arguments_; } - -public: - - // - // Basic overrides - // - - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const =0; - - /// Entry point to profile all operations in the manifest - virtual int profile_all( - Options const &options, - library::Manifest const &manifest, - DeviceContext &device_context); - -public: - - // - // Operation-specific phases of verification and profiling - // - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) = 0; - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) = 0; - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) = 0; - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) = 0; - -public: - - // - // Static helpers - // - - /// Sleep for a given duration in ms - static void sleep(int sleep_duration); - - /// Returns true if the current operation description satisfies the problem space - static bool satisfies( - library::OperationDescription const &op_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Compares tensors for equality - static Disposition compare_tensors( - Options const &options, - DeviceAllocation &experimental, - DeviceAllocation &reference, - int64_t count = 0); - - static void save_workspace( - DeviceContext &device_context, - Options const &options, - library::OperationDescription const &desc, - library::Provider provider, - library::Provider verification_provider = library::Provider::kInvalid); - - /// Helper to set a performance result member - static void set_argument( - PerformanceResult &result, - char const *name, - ProblemSpace const &problem_space, - std::string const &value); - - /// Helper to set a performance result member - static void set_argument( - PerformanceResult &result, - char const *name, - ProblemSpace const &problem_space, - int64_t value); - -protected: - - /// Sets operation description - static void initialize_result_( - PerformanceResult &result, - library::OperationDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Method to profile an initialized CUTLASS operation - virtual Status profile_cutlass_( - PerformanceResult &result, - Options const &options, - library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace); - - /// Profiles the GPU kernel launched in `func` running simultaneously on all - /// requested devices. - Status profile_kernel_w_cuda_graphs_( - PerformanceResult& result, - Options const& options, - std::function const& func, - std::vector const& streams); - - Status profile_kernel_( - PerformanceResult& result, - Options const& options, - std::function const& func, - std::vector const& streams); - - /// Profiles the GPU kernel launched in `func` on the `stream` - Status profile_kernel_( - PerformanceResult& result, - Options const& options, - std::function const& func, - cudaStream_t stream = nullptr); - - /// Profiles the GPU kernel launched in `func` on the `stream` - Status profile_kernel_no_cuda_graphs_( - PerformanceResult& result, - Options const& options, - std::function const& func, - cudaStream_t stream = nullptr); - -private: - /// finds string matches filter_string in operation_name - bool find_string_matches_( - std::string const &filter_string, - std::string const &operation_name); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Vector of owning operation profilers -using OperationProfilerVector = std::vector>; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h deleted file mode 100644 index 1a957b36eea35f7c0a5366645c3a62298ca56dea..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h +++ /dev/null @@ -1,384 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Command line options for performance test program -*/ - -#pragma once - -#include -#include -#include - -#include - -#include "cutlass/util/command_line.h" -#include "cutlass/util/distribution.h" -#include "cutlass/library/library.h" - -#include "enumerated_types.h" - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Global options -class Options { -public: - - /// Cublas and cuDNN options - struct Library { - - // - // Data members - // - - /// Algorithm mode - AlgorithmMode algorithm_mode; - - /// Algorithm enumerants - std::vector algorithms; - - // - // Methods - // - - explicit Library(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - }; - - /// Options related to the selected device - struct Device { - - /// Device ID - std::vector devices; - - /// Number of total devices - /// This is not set by the user, it is set by automatically - int num_devices; - - /// CUDA Device properties - std::vector properties; - - /// Total memory allocation on each device - size_t maximum_capacity; - - private: - /// SM Count - /// Limits the number of SMs to use on each device - int sm_count; - - // - // Methods - // - public: - explicit Device(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - void print_device_info(std::ostream &out) const; - - /// Returns the device ID from a device index - int device_id(size_t device_index) const; - - /// Returns the sm_count if set, otherwise returns the number of SMs on the device - int get_sm_count(int device_index) const; - - /// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.) - int compute_capability(int device_index) const; - }; - - /// Options related to initializing input tensors - struct Initialization { - - /// If true, data is initialized randomly. If false, no initialization is performed after - /// allocating tensors. - bool enabled; - - /// If true, data distribution is set by the user and is not allowed to change - /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) - bool fix_data_distribution; - - /// Data distribution for input tensors - Distribution data_distribution; - - /// Source of random tensor elements - library::Provider provider; - - /// Random number generator seed. - int seed; - - // - // Methods - // - - explicit Initialization(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - - /// Helper to parse a Distribution object from the command line parser - static void get_distribution( - cutlass::CommandLine const &args, - std::string const &arg, - cutlass::Distribution &dist); - }; - - /// Options related to verification of the result - struct Verification { - - // - // Data members - // - - /// If true, kernels are verified before they are profiled - bool enabled; - - /// If true, causes profiler to return an error code if no reference check is run. - /// Only valid when verification is enabled. - bool required; - - /// Relative error threshold - zero to require bit-level consistency - double epsilon; - - /// Values smaller than this are assumed to be zero - double nonzero_floor; - - /// List of providers used to verify each result - ProviderVector providers; - - /// Indicates when to save the workspace - SaveWorkspace save_workspace; - - // - // Methods - // - - explicit Verification(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - - /// Returns true if a provider is enabled - bool provider_enabled(library::Provider provider) const; - - /// Returns the index of a provider if its enabled - size_t index(library::Provider provider) const; - }; - - /// Options related to profiling - struct Profiling { - - /// Number of workspaces to rotate through to avoid cache-resident working sets - int workspace_count{0}; - - /// Number of iterations to warmup each kernel prior to profiling - int warmup_iterations{10}; - - /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration - /// This will always override profiling-duration and min-iterations. - int iterations{100}; - - /// Time to spend profiling each kernel (ms) - int duration{10}; - - /// Minimum number of iterations to profile - int min_iterations{10}; - - /// If true, profiling with cuda graph enabled. - bool use_cuda_graphs{false}; - - /// If enabled, the CUTLASS profiler searches for the best-performing kernel - /// within the subset of kernels matching a kernel filter regex. The best - /// performance is determined by screening over a set of predefined M/N/K - /// sizes and performance-related parameters, including cluster shapes, - /// swizzle sizes, and rasterization orders. - /// For now, it only supports legacy GEMM and blockscaled GEMM. - bool enable_kernel_performance_search{false}; - - /// If enabled, the CUTLASS profiler searches for the best-performing kernel - /// for a given M/N/K problem size by evaluating various performance-related - /// parameters such as cluster shapes, swizzle sizes, and rasterization orders. - /// For now, it only supports legacy GEMM and blockscaled GEMM. - bool enable_best_kernel_for_fixed_shape{false}; - - /// Number of ms to sleep between profiling periods (ms) - int sleep_duration{50}; - - /// If true, profiling is actually conducted. - bool enabled{true}; - - /// If true, profiling returns an error code if no kernels are found to match the filters. - bool error_on_no_match{false}; - - /// If true, profiling returns an error code if no kernel are profiled - // Sometimes the kernel matches but failed to profile (e.g. can_implement() error) - bool error_if_nothing_is_profiled{false}; - - /// List of providers of each functionality to be profiled - ProviderVector providers; - - // - // Methods - // - - explicit Profiling(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - - /// Returns true if a provider is enabled - bool provider_enabled(library::Provider provider) const; - - /// Returns the index of a provider if its enabled - size_t index(library::Provider provider) const; - }; - - /// Options related to reporting - struct Report { - - /// If true, result is appended to possibly existing file - bool append; - - /// Path to a file containing results - std::string output_path; - - /// Path to a file containing junit xml results - std::string junit_output_path; - - /// Sequence of tags to attach to each result - std::vector> pivot_tags; - - /// If true, reports status of all kernels including those that were - /// not run for the given arguments - bool report_not_run; - - /// Prints human-readable text to stdout. If false, nothing is written to stdout - bool verbose; - - /// Sort results by flops-per-byte - bool sort_flops_per_byte; - - /// Sort results by flops-per-second - bool sort_flops_per_sec; - - /// Prints the name of the kernel being profiled before running the kernel. - /// This is useful for determining which kernel is causing a run of the profiler to hang - bool print_kernel_before_running; - - // - // Methods - // - - explicit Report(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - }; - - /// Options related to printing usage and version information - struct About { - - /// If true, usage is printed and the program ends. - bool help; - - /// Prints version string - bool version; - - /// Print information about devices - bool device_info; - - // - // Methods - // - - explicit About(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out, int indent = 0) const; - - static void print_version(std::ostream &out); - }; - -public: - - // - // Data members - // - - /// Top-level execution mode - ExecutionMode execution_mode; - - /// Name of math function to profile - library::OperationKind operation_kind; - - /// Vector of operation name substrings - std::vector operation_names; - - /// Map of problems to run for each operation - /// [operation_name] -> vector of problems, each problem specified as a vector of [argument name] -> [argument value] - std::unordered_map> operation_problems; - - /// Vector of operation name substrings - std::vector excluded_operation_names; - - - // - // Detailed configuration options - // - - /// Configuration - CommandLine cmdline; - Device device; - Initialization initialization; - Library library; - Verification verification; - Profiling profiling; - Report report; - About about; - -public: - - explicit Options(CommandLine const &cmdline); - - void print_usage(std::ostream &out) const; - void print_options(std::ostream &out) const; - - static std::string indent_str(int indent); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h deleted file mode 100644 index 07102c99bc0f38a071e1ab828aab30678a3e2d44..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h +++ /dev/null @@ -1,128 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Class performing output during profiling -*/ - -#pragma once - -#include -#include - -// CUTLASS Profiler includes -#include "options.h" -#include "enumerated_types.h" -#include "performance_result.h" - -// CUTLASS Library includes -#include "cutlass/library/library.h" - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -class PerformanceReport { -private: - - /// Reference to options - Options const &options_; - - /// Operation kind - library::OperationKind op_kind_; - - /// Operation file name containing performance report of op_kind - std::string op_file_name_; - - /// Output file containing results - std::ofstream output_file_; - - /// Operation file name containing junit performance report of op_kind - std::string op_junit_file_name_; - - /// Output file containing junit results - std::ofstream junit_output_file_; - - /// Flag indicating the performance report is valid - bool good_; - - /// Vector of argument names - std::vector argument_names_; - - /// Counter uniquely identifying problem within the report - size_t problem_index_; - - /// Collection of all results - PerformanceResultVector concatenated_results_; - -public: - - PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); - ~PerformanceReport(); - - bool good() const { return good_; } - - void next_problem(); - void append_result(PerformanceResult result); - void sort_flops_per_byte(PerformanceResultVector &results); - void sort_flops_per_sec(PerformanceResultVector &results); - void append_results(PerformanceResultVector const &results); - -public: - - /// Prints the CSV header - std::ostream & print_csv_header_(std::ostream &out); - - /// Prints the CSV - std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); - - /// @defgroup jUnit Result Generation - /// Functions related to generation of the jUnit results - /// @{ - - std::ostream & print_junit_header_(std::ostream &out); - std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); - std::ostream & print_junit_footer_(std::ostream &out); - - /// @} - - /// Prints the result in human readable form - std::ostream & print_result_pretty_( - std::ostream &out, - PerformanceResult const &result, - bool use_shell_coloring = true); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h deleted file mode 100644 index 986ac89bc86a267ce8fb181a986f28f3f0936566..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h +++ /dev/null @@ -1,137 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function -*/ - -#pragma once - -#include - -#include "cutlass/cutlass.h" - -// CUTLASS Profiler includes -#include "enumerated_types.h" - -// CUTLASS Library includes -#include "cutlass/library/library.h" - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Performance result object -struct PerformanceResult { - - /// Index of problem - size_t problem_index; - - /// library::Provider - library::Provider provider; - - /// Operation kind - library::OperationKind op_kind; - - /// CUTLASS status result from kernels (success or failure) - // Status does information on verification - Status status; - - /// Outcome of verification (worst case verification result) - Disposition disposition; - - /// Outcome of verification (all verification results) - DispositionMap verification_map; - - /// Operation name - std::string operation_name; - - /// Stringified vector of argument values - std::vector > arguments; - - /// Number of bytes read or written - int64_t bytes; - - /// Number of DL flops performed by the math function - int64_t flops; - - /// Average runtime in ms - double runtime; - - /// Average runtime in ms per device - std::vector runtime_vector; - - // - // Members - // - - /// Ctor - PerformanceResult(): - problem_index(0), - op_kind(library::OperationKind::kInvalid), - provider(library::Provider::kInvalid), - disposition(Disposition::kNotRun), - status(Status::kInvalid), - bytes(0), - flops(0), - runtime(0) - { } - - // Copy constructor for deep copy - PerformanceResult(const PerformanceResult& other) = default; - - // Explicitly define copy assignment operator - PerformanceResult& operator=(const PerformanceResult& other) = default; - - /// Returns true if the runtime is valid - bool good() const { - return runtime > 0; - } - - /// Math throughput in units of GFLOP/s - double gflops_per_sec() const { - return double(flops) / runtime / 1.0e6; - } - - /// memory bandwidth in units of GiB/s - double gbytes_per_sec() const { - return double(bytes) / double(1 << 30) / runtime * 1000.0; - } - -}; - -using PerformanceResultVector = std::vector; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h deleted file mode 100644 index 9bdbec657c10cff0dafebd2cb6cd52057f3695c9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h +++ /dev/null @@ -1,1039 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief - - "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, - bug-ridden, slow implementation of half of Common Lisp." - - - Greenspun's Tenth Rule of Programming - - - cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian - product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. - - These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, - verify and profile various operations when they are compatible with the command line, and - construct data tables of results that are convenient inputs to post processing in Excel or Pandas. - - By executing multiple problems per invocation, startup overheads may be amortized across many - kernel launches. -*/ - -#pragma once - -// Standard Library includes -#include -#include -#include -#include -#include - -// CUTLASS Utility includes -#include "cutlass/util/command_line.h" - -// CUTLASS Library includes -#include "cutlass/library/library.h" - -// Profiler includes -#include "enumerated_types.h" - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines the argument schema -struct ArgumentDescription { - - /// Type of argument - ArgumentTypeID type; - - /// Prioritized array of aliases used in command line parsing - std::vector aliases; - - /// Description of argument - std::string description; - - // - // Methods - // - - /// Default ctor - ArgumentDescription(): - type(ArgumentTypeID::kInvalid) { } - - /// Constructor with aliases - ArgumentDescription( - ArgumentTypeID type_, - std::vector const &aliases_, - std::string const &description_ - ): - type(type_), aliases(aliases_), description(description_) { } -}; - -/// Vector of arguments -using ArgumentDescriptionVector = std::vector; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Base class for kernel arguments -struct KernelArgument { - - // - // Type definitions - // - - /// Value base class - struct Value { - - KernelArgument const *argument; - bool not_null; - - // - // Methods - // - - Value( - KernelArgument const *argument_ = nullptr, - bool not_null_ = true - ): argument(argument_), not_null(not_null_) { } - - virtual ~Value() { } - - virtual std::ostream &print(std::ostream &out) const =0; - }; - - /// Abstract base class to iterate over values within arguments - struct ValueIterator { - - /// Indicates type of kernel argument - KernelArgument const *argument; - - /// If the iterator points to an argument that is null, it needs to be distinguished - /// from end. - bool null_argument; - - // - // Methods - // - - /// Constructs a value iterator - no methods are valid if argument_ == nullptr - ValueIterator( - KernelArgument const *argument_ = nullptr, - bool null_argument_ = false): - argument(argument_), null_argument(null_argument_) { - - if (!argument_->not_null()) { - null_argument = true; - } - } - - virtual ~ValueIterator() { } - - /// Advances to next point in range - virtual void operator++() = 0; - - /// Compares against another value iterator - must be of the same KernelArgument type - virtual bool operator==(ValueIterator const &it) const = 0; - - /// Returns a unique_ptr object pointing to a newly created value object - virtual std::unique_ptr at() const = 0; - - /// Gets the type of the iterator - ArgumentTypeID type() const { - return argument->description->type; - } - - /// Helper to compute inequality - bool operator!=(ValueIterator const &it) const { - return !(*this == it); - } - - std::ostream &print(std::ostream &out) const; - }; - - // - // Data members - // - - /// Describes the argument - ArgumentDescription const *description; - - /// Parent node - KernelArgument *parent; - - /// Sequence in which the kernel argument is to be iterated over. - /// Smaller means faster changing. -1 is don't care - int ordinal; - - // - // Methods - // - - /// Default ctor - KernelArgument( - ArgumentDescription const *description_ = nullptr, - KernelArgument *parent_ = nullptr, - int ordinal_ = -1 - ): description(description_), parent(parent_), ordinal(ordinal_) { } - - virtual ~KernelArgument(); - - /// Returns true if the kernel argument iself is empty - virtual bool not_null() const =0; - - /// Returns a string name for debugging - std::string qualified_name() const { - if (description) { - if (description->aliases.empty()) { - return ""; - } - return description->aliases.front(); - } - return ""; - } - - virtual std::unique_ptr begin() const =0; - virtual std::unique_ptr end() const =0; -}; - -using KernelArgumentVector = std::vector>; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel -/// type. -struct ScalarArgument : public KernelArgument { - - // - // Type definitions - // - - /// Value type - struct ScalarValue : public KernelArgument::Value { - - std::string value; - - // - // Methods - // - - ScalarValue( - std::string const &value_ = "", - ScalarArgument const *argument = nullptr, - bool not_null_ = true - ); - - virtual std::ostream &print(std::ostream &out) const; - }; - - using ValueCollection = std::vector; - - /// Abstract base class to iterate over values within arguments - struct ScalarValueIterator : public KernelArgument::ValueIterator { - - // - // Data members - // - - ValueCollection::const_iterator value_it; - - // - // Methods - // - - explicit ScalarValueIterator(ScalarArgument const *argument = nullptr); - - virtual void operator++(); - virtual bool operator==(ValueIterator const &it) const; - - /// Gets the value pointed to - virtual std::unique_ptr at() const; - }; - - // - // Data members - // - - /// Set of possible values - ValueCollection values; - - // - // Methods - // - - /// Default ctor - explicit ScalarArgument( - ArgumentDescription const *description - ): - KernelArgument(description) { } - - virtual bool not_null() const { - return !values.empty(); - } - - virtual std::unique_ptr begin() const; - virtual std::unique_ptr end() const; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Closed range supporting additive increment -struct Range { - - // - // Type definitions - // - - enum class Mode { - kSequence, - kRandom, - kRandomLog2, - kInvalid - }; - - struct Iterator { - - int64_t value; - int64_t increment; - Range const *range; - - // - // Methods - // - - Iterator( - int64_t value_ = 0, - int64_t increment_ = 1, - Range const *range_ = nullptr - ): - value(value_), increment(increment_), range(range_) { } - - Iterator & operator++() { - value += increment; - return *this; - } - - Iterator operator++(int) { - Iterator self(*this); - ++(*this); - return self; - } - - bool operator==(Iterator const &it) const { - return value == it.value; - } - - bool operator!=(Iterator const &it) const { - return !(*this == it); - } - - static int64_t round(int64_t value, int64_t divisible) { - int64_t rem = (value % divisible); - - // Round either up or down - if (rem > divisible / 2) { - value += (divisible - rem); - } - else { - value -= rem; - } - - return value; - } - - int64_t at() const { - if (!range) { - return value; - } - - switch (range->mode) { - case Mode::kSequence: return value; - - case Mode::kRandom: { - double rnd = double(range->minimum) + - double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); - - int64_t value = int64_t(rnd); - - return round(value, range->divisible); - } - break; - - case Mode::kRandomLog2: { - double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); - double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); - double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); - - int64_t value = int64_t(std::pow(2.0, rnd)); - - return round(value, range->divisible); - } - break; - default: break; - } - return value; - } - - int64_t operator*() const { - return at(); - } - }; - - // - // Data members - // - - int64_t first; ///< first element in range - int64_t last; ///< last element in range - int64_t increment; ///< additive increment between values - - Mode mode; ///< mode selection enables alternative values - int64_t minimum; ///< minimum value to return - int64_t maximum; ///< maximum value to return - int64_t divisible; ///< rounds value down to an integer multiple of this value - - // - // Methods - // - - /// Default constructor - range acts as a scalar - Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } - - /// Range acts as a range - Range( - int64_t first_, - int64_t last_, - int64_t increment_ = 1, - Mode mode_ = Mode::kSequence, - int64_t minimum_ = 0, - int64_t maximum_ = 0, - int64_t divisible_ = 1 - ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { - - // Helpers to avoid constructing invalid ranges - if (increment > 0) { - if (last < first) { - std::swap(last, first); - } - } - else if (increment < 0) { - if (first < last) { - std::swap(last, first); - } - } - else if (last != first) { - last = first; - increment = 1; - } - } - - /// Helper to construct a sequence range - static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { - return Range(first_, last_, increment_, Mode::kSequence); - } - - /// Helper to construct a range that is a random distribution - static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { - return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); - } - - /// Helper to construct a range that is a random distribution over a log scale - static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { - return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); - } - - /// Returns an iterator to the first element within the range - Iterator begin() const { - return Iterator(first, increment, this); - } - - /// Returns an iterator to the first element *after* the range - Iterator end() const { - return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); - } -}; - -/// Integer-valued argument - represented as a list of integer-valued ranges -struct IntegerArgument : public KernelArgument { - - // - // Type definitions - // - - /// Value type - struct IntegerValue : public KernelArgument::Value { - - int64_t value; - - // - // Methods - // - - IntegerValue( - int64_t value_ = 0, - IntegerArgument const *argument_ = nullptr, - bool not_null_ = true - ); - - /// Pretty printer for debugging - virtual std::ostream &print(std::ostream &out) const; - }; - - /// Collection of ranges represent the IntegerArgument's state - using RangeCollection = std::vector; - - /// Abstract base class to iterate over values within arguments - struct IntegerValueIterator : public KernelArgument::ValueIterator { - - // - // Data members - // - - RangeCollection::const_iterator range_it; - Range::Iterator value_it; - - // - // Methods - // - - IntegerValueIterator(); - IntegerValueIterator(IntegerArgument const *argument); - - virtual void operator++(); - virtual bool operator==(ValueIterator const &it) const; - - /// Gets the value pointed to - virtual std::unique_ptr at() const; - }; - - // - // Data members - // - - /// Set of possible values - RangeCollection ranges; - - // - // Methods - // - - /// Default ctor - IntegerArgument( - ArgumentDescription const *description - ): - KernelArgument(description) { } - - virtual bool not_null() const { - bool _not_null = !ranges.empty(); - return _not_null; - } - - virtual std::unique_ptr begin() const; - virtual std::unique_ptr end() const; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure defining the data type of tensors -struct TensorArgument : public KernelArgument { - - // - // Type definitions - // - - struct TensorDescription { - - /// Data type of elements - library::NumericTypeID element; - - /// Layout definition - library::LayoutTypeID layout; - - /// Computed extent - std::vector extent; - - /// Enables directly specifying stride value used to size tensor - std::vector stride; - - // - // Methods - // - - TensorDescription( - library::NumericTypeID element_ = library::NumericTypeID::kUnknown, - library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, - std::vector extent_ = std::vector(), - std::vector stride_ = std::vector() - ): - element(element_), layout(layout_), extent(extent_), stride(stride_) {} - }; - - using ValueCollection = std::vector; - - /// Value structure - struct TensorValue : public KernelArgument::Value { - - TensorDescription desc; - - // - // Methods - // - - TensorValue( - TensorDescription const &desc_ = TensorDescription(), - TensorArgument const *argument_ = nullptr, - bool not_null_ = true - ); - - /// Pretty printer for debugging - virtual std::ostream &print(std::ostream &out) const; - }; - - /// Abstract base class to iterate over values within arguments - struct TensorValueIterator : public KernelArgument::ValueIterator { - - // - // Data members - // - - ValueCollection::const_iterator value_it; - - // - // Methods - // - - explicit TensorValueIterator(TensorArgument const *argument_); - - virtual void operator++(); - virtual bool operator==(ValueIterator const &it) const; - - /// Gets the value pointed to - virtual std::unique_ptr at() const; - }; - - /// Set of possible values - ValueCollection values; - - // - // Methods - // - - /// Default ctor - explicit TensorArgument( - ArgumentDescription const *description - ): - KernelArgument(description) { } - - virtual bool not_null() const { - return !values.empty(); - } - - virtual std::unique_ptr begin() const; - virtual std::unique_ptr end() const; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Numeric data type -struct EnumeratedTypeArgument : public KernelArgument { - - // - // Type definitions - // - - struct EnumeratedTypeValue : public KernelArgument::Value { - - /// Data type of element - std::string element; - - // - // Methods - // - - EnumeratedTypeValue( - std::string const &element_ = std::string(), - EnumeratedTypeArgument const *argument_ = nullptr, - bool not_null_ = true - ); - - /// Pretty printer for debugging - virtual std::ostream &print(std::ostream &out) const; - }; - - using ValueCollection = std::vector; - - /// Abstract base class to iterate over values within arguments - struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { - - // - // Data members - // - - ValueCollection::const_iterator value_it; - - // - // Methods - // - - explicit EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); - - virtual void operator++(); - virtual bool operator==(ValueIterator const &it) const; - - /// Gets the value pointed to - virtual std::unique_ptr at() const; - }; - - // - // Data members - // - - ValueCollection values; - - // - // Members - // - - /// Default ctor - explicit EnumeratedTypeArgument(ArgumentDescription const *description): - KernelArgument(description) {} - - virtual bool not_null() const { - return !values.empty(); - } - - virtual std::unique_ptr begin() const; - virtual std::unique_ptr end() const; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Object storing the space argument values -class ProblemSpace { -public: - - /// Tuple of arguments - using Problem = std::vector>; - - /// Type used to iterator over things - using IteratorVector = std::vector>; - - /// Iterates over points in the design space - class Iterator { - private: - - /// One iterator per argument - IteratorVector iterators; - - public: - - // - // Methods - // - - explicit Iterator(); - Iterator(ProblemSpace const &problem_space); - Iterator(Iterator &&it); - - // Rule of three - Iterator(Iterator const &) = delete; - Iterator &operator=(Iterator const &it) = delete; - ~Iterator() = default; - - /// Pre-increment - advances to next point in argument range - void operator++(); - - /// Gets the current argument value - Problem at() const; - - /// Moves iterator to end - void move_to_end(); - - /// Equality operator - bool operator==(Iterator const &it) const; - - /// Inequality operator - bool operator!=(Iterator const &it) const { - return !(*this == it); - } - - /// Helper to call at() method - Problem operator*() const { - return at(); - } - - /// Helper to print iterator state - std::ostream & print(std::ostream &out) const; - - private: - - /// Helper for recursively constructing iterators - void construct_(KernelArgument const *argument); - }; - -public: - - // - // Data members - // - - KernelArgumentVector arguments; - - /// Map of argument names to their position within the argument vector - std::unordered_map argument_index_map; - -public: - - // - // Methods - // - - /// Default ctor - ProblemSpace() = default; - - /// Constructs a problem space from a vector of arguments. This vector must outlive - /// the ProblemSpace object, which stores pointers to objects within the - /// ArgumentDescriptionVector. - ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); - - Iterator begin() const; // returns an iterator to the first point in the range - Iterator end() const; // returns an iterator to the first point after the range - - /// Returns the index of an argument by name - size_t argument_index(char const *name) const; - - /// Gets all argument names as an ordered vector - std::vector argument_names() const; - - /// Returns the number of dimensions of the problem space - size_t rank() const { return arguments.size(); } - -private: - - /// Helper for recursively cloning - void clone_( - KernelArgumentVector &kernel_args, - ArgumentDescription const *arg_desc); - - /// Parses command line argument - void parse_( - KernelArgument *arg, - CommandLine const &cmdline); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Lexically casts an argument to an int if it is defined. Returns true if not null. -bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_int( - int &int_value, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_int( - int64_t &int_value, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -bool arg_as_bool(bool &bool_value, KernelArgument::Value const *value_ptr); - -bool arg_as_bool(bool &bool_value, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_NumericTypeID( - library::NumericTypeID &numeric_type, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_LayoutTypeID( - library::LayoutTypeID &layout_type, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_OpcodeClassID( - library::OpcodeClassID &opcode_class, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_SplitKModeID( - library::SplitKMode &split_k_mode, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_ConvModeID( - library::ConvModeID &conv_mode, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_IteratorAlgorithmID( - library::IteratorAlgorithmID &iterator_algorithm, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_RuntimeDatatype( - library::RuntimeDatatype &runtime_datatype, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_RasterOrder( - library::RasterOrder &raster_order, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -bool arg_as_ProviderID( - library::Provider &provider, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. -bool arg_as_scalar( - std::vector &bytes, - library::NumericTypeID numeric_type, - KernelArgument::Value const *value_ptr); - -/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. -bool arg_as_scalar( - std::vector &bytes, - library::NumericTypeID numeric_type, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -bool arg_as_string( - std::string& arg, - char const* name, - ProblemSpace const& problem_space, - ProblemSpace::Problem const& problem); - -/// Returns true if a tensor description satisfies a `tensor` value -bool tensor_description_satisfies( - library::TensorDescription const &tensor_desc, - TensorArgument::TensorValue const *value_ptr); - -/// Returns true if a tensor description satisfies a `tensor` value -bool tensor_description_satisfies( - library::TensorDescription const &tensor_desc, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - -/// Returns true if a conv kind satisfies the value -bool conv_kind_satisfies( - library::ConvKind const &conv_kind, - EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); - -/// Returns true if a conv kind satisfies the value -bool conv_kind_satisfies( - library::ConvKind const &conv_kind, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -/// Returns true if a iterator algorithm satisfies the value -bool iterator_algorithm_satisfies( - library::IteratorAlgorithmID const &iterator_algorithm, - EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); - -/// Returns true if a iterator algorithm satisfies the value -bool iterator_algorithm_satisfies( - library::IteratorAlgorithmID const &iterator_algorithm, - char const *name, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h deleted file mode 100644 index ba47a6832077984c334a5467257a151735b088b3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h +++ /dev/null @@ -1,229 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function - - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/blas3.h" -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Abstract base class for each math function -class Rank2KOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct RankKProblem { - int64_t n; - int64_t k; - int64_t lda; - int64_t ldb; - int64_t ldc; - FillMode fill_mode; - BlasMode blas_mode; - std::vector alpha; - std::vector beta; - int64_t split_k_slices; - int64_t batch_count; - - // - // Methods - // - - RankKProblem(): - n(16), k(16), lda(0), ldc(0), - fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), - split_k_slices(1), batch_count(1) { } - - /// Parses the problem - Status parse( - library::RankKDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Total number of bytes loaded - int64_t bytes(library::RankKDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::RankKDescription const &operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::RankKDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct RankKWorkspace { - - DeviceAllocation *A; - DeviceAllocation *B; - DeviceAllocation *C; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - library::RankKConfiguration configuration; - library::RankKArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - // - // Methods - // - - RankKWorkspace(): - A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - RankKProblem problem_; - - /// Device memory allocations - RankKWorkspace rank_k_workspace_; - - -public: - // - // Methods - // - - /// Ctor - Rank2KOperationProfiler(Options const &options); - - /// Destructor - virtual ~Rank2KOperationProfiler(); - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::RankKDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h deleted file mode 100644 index fff190a7570cd5811c6e5de6284bf96e40c404b7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h +++ /dev/null @@ -1,227 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function - - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/blas3.h" -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Abstract base class for each math function -class RankKOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct RankKProblem { - int64_t n; - int64_t k; - int64_t lda; - int64_t ldc; - FillMode fill_mode; - BlasMode blas_mode; - std::vector alpha; - std::vector beta; - int64_t split_k_slices; - int64_t batch_count; - - // - // Methods - // - - RankKProblem(): - n(16), k(16), lda(0), ldc(0), - fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), - split_k_slices(1), batch_count(1) { } - - /// Parses the problem - Status parse( - library::RankKDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Total number of bytes loaded - int64_t bytes(library::RankKDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::RankKDescription const &operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::RankKDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct RankKWorkspace { - - DeviceAllocation *A; - DeviceAllocation *C; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - library::RankKConfiguration configuration; - library::RankKArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - // - // Methods - // - - RankKWorkspace(): - A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - RankKProblem problem_; - - /// Device memory allocations - RankKWorkspace rank_k_workspace_; - - -public: - // - // Methods - // - - /// Ctor - RankKOperationProfiler(Options const &options); - - /// Destructor - virtual ~RankKOperationProfiler(); - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::RankKDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h deleted file mode 100644 index 0c81ef4637175a6de1f44cedddf319436aaff24d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h +++ /dev/null @@ -1,173 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines profiling functionality for reduction operation - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#if CUTLASS_ENABLE_CUDNN -#include "cudnn_helpers.h" -#endif //#if CUTLASS_ENABLE_CUDNN -#include "debug.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class ReductionOperationProfiler : public OperationProfiler { -public: - - - /// Workspace used - struct ReductionWorkspace { - - /// Conv device allocations - DeviceAllocation *Workspace; - DeviceAllocation *Source; - DeviceAllocation *Destination; - DeviceAllocation *Reference; - - /// Library configuration and arguments - library::ReductionConfiguration configuration; - library::ReductionArguments arguments; - - /// Buffer used for the cutlass operations' host workspace - std::vector host_workspace; - - /// Buffer used for the cutlass operations' device workspace - DeviceAllocation device_workspace; - - // - // Methods - // - - ReductionWorkspace(): - Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } - }; - -protected: - - // - // Data members - // - - /// Reduction problem obtained from problem space - MatrixCoord problem_; - - /// Device memory allocations - ReductionWorkspace conv_workspace_; - - -public: - // - // Methods - // - - /// Ctor - ReductionOperationProfiler(Options const &options); - - /// Destructor - virtual ~ReductionOperationProfiler(); - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h deleted file mode 100644 index 60204d8c9d458ab12020a6492de23174739aa584..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h +++ /dev/null @@ -1,214 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" -#include "gemm_operation_profiler.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class SparseGemmOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct SparseGemmProblem { - int64_t m; - int64_t n; - int64_t k; - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t lde; - std::vector alpha; - std::vector beta; - int64_t split_k_slices; - int64_t batch_count; - static int const sparse = 2; - // every 128b ElementA uses one elementE - int elements_per_128b; - - // - // Methods - // - - SparseGemmProblem(): - m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } - - /// Parses the problem - Status parse( - library::SparseGemmDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::SparseGemmDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct SparseGemmWorkspace { - - DeviceAllocation *A; - DeviceAllocation *B; - DeviceAllocation *C; - DeviceAllocation *E; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - library::SparseGemmConfiguration configuration; - library::SparseGemmArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - // - // Methods - // - - SparseGemmWorkspace(): - A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } - }; - -protected: - - // - // Data members - // - - // GEMM problem - SparseGemmProblem problem_; - - /// Device memory allocations - SparseGemmWorkspace gemm_workspace_; - - -public: - // - // Methods - // - - /// Ctor - SparseGemmOperationProfiler(Options const &options); - - /// Destructor - virtual ~SparseGemmOperationProfiler(); - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::SparseGemmDescription const &operation_desc, - ProblemSpace const &problem_space); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h deleted file mode 100644 index 94ded5e803bf914e5ae8c4ebb867cfe42ef829bc..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h +++ /dev/null @@ -1,230 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function - - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/blas3.h" -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -/// Abstract base class for each math function -class SymmOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct SymmProblem { - int64_t m; - int64_t n; - int64_t lda; - int64_t ldb; - int64_t ldc; - SideMode side_mode; - FillMode fill_mode; - BlasMode blas_mode; - std::vector alpha; - std::vector beta; - int64_t split_k_slices; - int64_t batch_count; - - // - // Methods - // - - SymmProblem(): - m(16), n(16), lda(0), ldb(0), ldc(0), - side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), - split_k_slices(1), batch_count(1) { } - - /// Parses the problem - Status parse( - library::SymmDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Total number of bytes loaded - int64_t bytes(library::SymmDescription const &operation_desc) const; - - /// Total number of flops computed - int64_t flops(library::SymmDescription const &operation_desc) const; - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::SymmDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct SymmWorkspace { - - DeviceAllocation *A; - DeviceAllocation *B; - DeviceAllocation *C; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - library::SymmConfiguration configuration; - library::SymmArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - // - // Methods - // - - SymmWorkspace(): - A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - SymmProblem problem_; - - /// Device memory allocations - SymmWorkspace symm_workspace_; - - -public: - // - // Methods - // - - /// Ctor - SymmOperationProfiler(Options const &options); - - /// Destructor - virtual ~SymmOperationProfiler(); - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::SymmDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h deleted file mode 100644 index 9f21dafa0ecc869840fdba0a9c4414a89bbf4a7d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h +++ /dev/null @@ -1,222 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines a math function - - -*/ - -#pragma once - -#include -#include -#include -#include -#include - -// CUTLASS Library includes -#include "cutlass/blas3.h" -#include "cutlass/library/library.h" -#include "cutlass/library/util.h" -#include "cutlass/library/manifest.h" - -// Profiler includes -#include "options.h" -#include "device_context.h" -#include "operation_profiler.h" -#include "performance_result.h" -#include "problem_space.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace profiler { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Abstract base class for each math function -class TrmmOperationProfiler : public OperationProfiler { -public: - - /// Problem structure obtained from problem space - struct TrmmProblem { - int64_t m; - int64_t n; - int64_t lda; - int64_t ldb; - int64_t ldd; - SideMode side_mode; - FillMode fill_mode; - DiagType diag_type; - std::vector alpha; - std::vector beta; - int64_t split_k_slices; - int64_t batch_count; - - // - // Methods - // - - TrmmProblem(): - m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } - - /// Parses the problem - Status parse( - library::TrmmDescription const &operation_desc, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes a performance result - void initialize_result( - PerformanceResult &result, - library::TrmmDescription const &operation_desc, - ProblemSpace const &problem_space); - }; - - /// Workspace used - struct TrmmWorkspace { - - DeviceAllocation *A; - DeviceAllocation *B; - DeviceAllocation *D; - DeviceAllocation *Computed; - DeviceAllocation *Reference; - - library::TrmmConfiguration configuration; - library::TrmmArguments arguments; - - /// Buffer used for the operation's host workspace - std::vector host_workspace; - - /// Buffer used for the operations' device workspace - DeviceAllocation device_workspace; - - // - // Methods - // - - TrmmWorkspace(): - A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } - }; - -protected: - - // - // Data members - // - - /// GEMM problem obtained from problem space - TrmmProblem problem_; - - /// Device memory allocations - TrmmWorkspace trmm_workspace_; - - -public: - // - // Methods - // - - /// Ctor - TrmmOperationProfiler(Options const &options); - - /// Destructor - virtual ~TrmmOperationProfiler(); - - /// Prints usage statement for the math function - virtual void print_usage(std::ostream &out) const; - - /// Prints examples - virtual void print_examples(std::ostream &out) const; - - /// Extracts the problem dimensions - virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Initializes workspace - virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Verifies CUTLASS against references - virtual bool verify_cutlass( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - - /// Measures performance results - virtual bool profile( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -protected: - - /// Initializes the performance result - void initialize_result_( - PerformanceResult &result, - Options const &options, - library::TrmmDescription const &operation_desc, - ProblemSpace const &problem_space); - - /// Verifies CUTLASS against references - bool verify_with_cublas_( - Options const &options, - PerformanceReport &report, - DeviceContext &device_context, - library::Operation const *operation, - ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace profiler -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp deleted file mode 100644 index c2727c989e645eca8e67a5d8d50391ced803cffa..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp +++ /dev/null @@ -1,67 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include - -struct GPU_Clock -{ - GPU_Clock() { - cudaEventCreate(&start_); - cudaEventCreate(&stop_); - cudaEventRecord(start_); - } - - ~GPU_Clock() { - cudaEventDestroy(start_); - cudaEventDestroy(stop_); - } - - void start() { - cudaEventRecord(start_); - } - - float milliseconds() { - cudaEventRecord(stop_); - cudaEventSynchronize(stop_); - float time; - cudaEventElapsedTime(&time, start_, stop_); - return time; - } - - float seconds() { - return milliseconds() * float(1e-3); - } - - private: - cudaEvent_t start_, stop_; -}; diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h deleted file mode 100644 index c95bd1cbeb56cc566394b155ea7ac24f07c28162..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h +++ /dev/null @@ -1,324 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * Utility for parsing command line arguments - */ - -#include -#include -#include -#include -#include -#include - -#include - -#include "cutlass/cutlass.h" - -namespace cutlass { - -/****************************************************************************** - * command_line - ******************************************************************************/ - -/** - * Utility for parsing command line arguments - */ -struct CommandLine { - std::vector keys; - std::vector values; - std::vector args; - - /** - * Constructor - */ - CommandLine(int argc, const char** argv) { - using namespace std; - - for (int i = 1; i < argc; i++) { - string arg = argv[i]; - - if ((arg[0] != '-') || (arg[1] != '-')) { - args.push_back(arg); - continue; - } - - string::size_type pos; - string key, val; - if ((pos = arg.find('=')) == string::npos) { - key = string(arg, 2, arg.length() - 2); - val = ""; - } else { - key = string(arg, 2, pos - 2); - val = string(arg, pos + 1, arg.length() - 1); - } - - keys.push_back(key); - values.push_back(val); - } - } - - /** - * Constructor to represent a command line from a map of [argument] -> [value] - */ - CommandLine(std::unordered_map& arg_map) { - for (const auto& [key, value] : arg_map) { - keys.push_back(key); - values.push_back(value); - } - } - - /** - * Checks whether a flag "--" is present in the commandline - */ - bool check_cmd_line_flag(const char* arg_name) const { - using namespace std; - - for (int i = 0; i < int(keys.size()); ++i) { - if (keys[i] == string(arg_name)) return true; - } - return false; - } - - /** - * Returns number of naked (non-flag and non-key-value) commandline parameters - */ - size_t num_naked_args() const { - return args.size(); - } - - /** - * Print naked (non-flag and non-key-value) commandline parameters - */ - void print_naked_args(std::ostream &out) const { - for (auto arg : args) { - out << " " << arg <<"\n"; - } - } - - /** - * Returns the commandline parameter for a given index (not including flags) - */ - template - void get_cmd_line_argument(size_t index, value_t& val) const { - using namespace std; - if (index < args.size()) { - istringstream str_stream(args[index]); - str_stream >> val; - } - } - - /** - * Obtains the boolean value specified for a given commandline parameter --= - */ - void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { - val = _default; - if (check_cmd_line_flag(arg_name)) { - std::string value; - get_cmd_line_argument(arg_name, value); - - val = !(value == "0" || value == "false"); - } - } - - /** - * Obtains the value specified for a given commandline parameter --= - */ - template - void get_cmd_line_argument(const char* arg_name, - value_t& val) const { - - get_cmd_line_argument(arg_name, val, val); - } - - /** - * Obtains the value specified for a given commandline parameter --= - */ - template - void get_cmd_line_argument(const char* arg_name, - value_t& val, - value_t const& _default) const { - using namespace std; - - val = _default; - - for (int i = 0; i < int(keys.size()); ++i) { - if (keys[i] == string(arg_name)) { - istringstream str_stream(values[i]); - str_stream >> val; - } - } - } - - /** - * Returns the values specified for a given commandline parameter --=,* - */ - template - void get_cmd_line_arguments(const char* arg_name, - std::vector& vals, - char sep = ',') const { - using namespace std; - - if (check_cmd_line_flag(arg_name)) { - // Clear any default values - vals.clear(); - - // Recover from multi-value string - for (size_t i = 0; i < keys.size(); ++i) { - if (keys[i] == string(arg_name)) { - string val_string(values[i]); - separate_string(val_string, vals, sep); - } - } - } - } - - /** - * Returns the values specified for a given commandline parameter - * --=,* - */ - void get_cmd_line_argument_pairs(const char* arg_name, - std::vector >& tokens, - char delim = ',', - char sep = ':') const { - if (check_cmd_line_flag(arg_name)) { - std::string value; - get_cmd_line_argument(arg_name, value); - - tokenize(tokens, value, delim, sep); - } - } - - /** - * Returns a list of ranges specified for a given commandline parameter - * --=,* - */ - void get_cmd_line_argument_ranges(const char* arg_name, - std::vector >& vals, - char delim = ',', - char sep = ':') const { - std::vector ranges; - get_cmd_line_arguments(arg_name, ranges, delim); - - for (std::vector::const_iterator range = ranges.begin(); - range != ranges.end(); ++range) { - - std::vector range_vals; - separate_string(*range, range_vals, sep); - vals.push_back(range_vals); - } - } - - /** - * The number of pairs parsed - */ - int parsed_argc() const { return (int)keys.size(); } - - //------------------------------------------------------------------------- - // Utility functions - //------------------------------------------------------------------------- - - /// Tokenizes a comma-delimited list of string pairs delimited by ':' - static void tokenize(std::vector >& tokens, - std::string const& str, - char delim = ',', - char sep = ':') { - // Home-built to avoid Boost dependency - size_t s_idx = 0; - size_t d_idx = std::string::npos; - while (s_idx < str.size()) { - d_idx = str.find_first_of(delim, s_idx); - - size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size()); - size_t sep_idx = str.find_first_of(sep, s_idx); - size_t offset = 1; - if (sep_idx == std::string::npos || sep_idx >= end_idx) { - sep_idx = end_idx; - offset = 0; - } - - std::pair item( - str.substr(s_idx, sep_idx - s_idx), - str.substr(sep_idx + offset, end_idx - sep_idx - offset)); - - tokens.push_back(item); - s_idx = end_idx + 1; - } - } - - /// Tokenizes a comma-delimited list of string pairs delimited by ':' - static void tokenize(std::vector& tokens, - std::string const& str, - char delim = ',', - char sep = ':') { - typedef std::vector > TokenVector; - typedef TokenVector::const_iterator token_iterator; - - std::vector > token_pairs; - tokenize(token_pairs, str, delim, sep); - for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { - tokens.push_back(tok->first); - } - } - - template - static void separate_string(std::string const& str, - std::vector& vals, - char sep = ',') { - std::istringstream str_stream(str); - std::string::size_type old_pos = 0; - std::string::size_type new_pos = 0; - - // Iterate -delimited values - value_t val; - while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { - if (new_pos != old_pos) { - str_stream.width(new_pos - old_pos); - str_stream >> val; - vals.push_back(val); - } - - // skip over delimiter - str_stream.ignore(1); - old_pos = new_pos + 1; - } - - // Read last value - str_stream >> val; - vals.push_back(val); - } -}; - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp deleted file mode 100644 index 8ace1e0a232ea7cccbb2089ec8432783c49410dd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp +++ /dev/null @@ -1,528 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include - -//-- BLAM_DEBUG_OUT --------------------------------------------------------- -#ifdef BLAM_DEBUG -# include -# ifndef BLAM_DEBUG_OUT -# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl -# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl -# endif // BLAM_DEBUG_OUT -#else -# ifndef BLAM_DEBUG_OUT -# define BLAM_DEBUG_OUT(msg) -# define BLAM_DEBUG_OUT_2(msg) -# endif // BLAM_DEBUG_OUT -#endif // BLAM_DEBUG - -// User could potentially define ComplexFloat/ComplexDouble instead of std:: -#ifndef BLAM_COMPLEX_TYPES -#define BLAM_COMPLEX_TYPES 1 -#include "cutlass/cutlass.h" -#include CUDA_STD_HEADER(complex) - -namespace blam { -template -using Complex = cuda::std::complex; -using ComplexFloat = cuda::std::complex; -using ComplexDouble = cuda::std::complex; -} -#endif // BLAM_COMPLEX_TYPES - -// User could potentially define Half instead of cute:: -#ifndef BLAM_HALF_TYPE -#define BLAM_HALF_TYPE 1 -#include -namespace blam { -using Half = cute::half_t; -} -#endif // BLAM_HALF_TYPE - -namespace blam -{ -namespace cublas -{ - -inline const char* -cublas_get_error(cublasStatus_t status) -{ - switch (status) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; - default: - return "CUBLAS_ERROR -- "; - } -} - -inline bool -cublas_is_error(cublasStatus_t status) -{ - return status != CUBLAS_STATUS_SUCCESS; -} - - -// hgemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const Half* alpha, - const Half* A, int ldA, - const Half* B, int ldB, - const Half* beta, - Half* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasHgemm"); - - return cublasGemmEx(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), CUDA_R_16F, ldA, - reinterpret_cast(B), CUDA_R_16F, ldB, - reinterpret_cast(beta), - reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, - CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -// mixed hf gemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const float* alpha, - const Half* A, int ldA, - const Half* B, int ldB, - const float* beta, - float* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); - - return cublasGemmEx(handle, transA, transB, - m, n, k, - alpha, - reinterpret_cast(A), CUDA_R_16F, ldA, - reinterpret_cast(B), CUDA_R_16F, ldB, - beta, - C, CUDA_R_32F, ldC, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -// igemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const int32_t* alpha, - const int8_t* A, int ldA, - const int8_t* B, int ldB, - const int32_t* beta, - int32_t* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasIgemm"); - - return cublasGemmEx(handle, transA, transB, - m, n, k, - alpha, - A, CUDA_R_8I, ldA, - B, CUDA_R_8I, ldB, - beta, - C, CUDA_R_32I, ldC, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -// sgemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const float* alpha, - const float* A, int ldA, - const float* B, int ldB, - const float* beta, - float* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasSgemm"); - - return cublasSgemm(handle, transA, transB, - m, n, k, - alpha, - A, ldA, - B, ldB, - beta, - C, ldC); -} - -// dgemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const double* alpha, - const double* A, int ldA, - const double* B, int ldB, - const double* beta, - double* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasDgemm"); - - return cublasDgemm(handle, transA, transB, - m, n, k, - alpha, - A, ldA, - B, ldB, - beta, - C, ldC); -} - -// cgemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const ComplexFloat* alpha, - const ComplexFloat* A, int ldA, - const ComplexFloat* B, int ldB, - const ComplexFloat* beta, - ComplexFloat* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasCgemm"); - - return cublasCgemm(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), ldA, - reinterpret_cast(B), ldB, - reinterpret_cast(beta), - reinterpret_cast(C), ldC); -} - -// zgemm -inline cublasStatus_t -gemm(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const ComplexDouble* alpha, - const ComplexDouble* A, int ldA, - const ComplexDouble* B, int ldB, - const ComplexDouble* beta, - ComplexDouble* C, int ldC) -{ - BLAM_DEBUG_OUT("cublasZgemm"); - - return cublasZgemm(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), ldA, - reinterpret_cast(B), ldB, - reinterpret_cast(beta), - reinterpret_cast(C), ldC); -} - -// hgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const Half* alpha, - const Half* A, int ldA, int loA, - const Half* B, int ldB, int loB, - const Half* beta, - Half* C, int ldC, int loC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); - - return cublasHgemmStridedBatched(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), ldA, loA, - reinterpret_cast(B), ldB, loB, - reinterpret_cast(beta), - reinterpret_cast<__half*>(C), ldC, loC, - batch_size); -} - -// sgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const float* alpha, - const float* A, int ldA, int loA, - const float* B, int ldB, int loB, - const float* beta, - float* C, int ldC, int loC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); - - return cublasSgemmStridedBatched(handle, transA, transB, - m, n, k, - alpha, - A, ldA, loA, - B, ldB, loB, - beta, - C, ldC, loC, - batch_size); -} - -// dgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const double* alpha, - const double* A, int ldA, int loA, - const double* B, int ldB, int loB, - const double* beta, - double* C, int ldC, int loC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); - - return cublasDgemmStridedBatched(handle, transA, transB, - m, n, k, - alpha, - A, ldA, loA, - B, ldB, loB, - beta, - C, ldC, loC, - batch_size); -} - -// cgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const ComplexFloat* alpha, - const ComplexFloat* A, int ldA, int loA, - const ComplexFloat* B, int ldB, int loB, - const ComplexFloat* beta, - ComplexFloat* C, int ldC, int loC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); - - return cublasCgemmStridedBatched(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), ldA, loA, - reinterpret_cast(B), ldB, loB, - reinterpret_cast(beta), - reinterpret_cast(C), ldC, loC, - batch_size); -} - -// zgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const ComplexDouble* alpha, - const ComplexDouble* A, int ldA, int loA, - const ComplexDouble* B, int ldB, int loB, - const ComplexDouble* beta, - ComplexDouble* C, int ldC, int loC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); - - return cublasZgemmStridedBatched(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), ldA, loA, - reinterpret_cast(B), ldB, loB, - reinterpret_cast(beta), - reinterpret_cast(C), ldC, loC, - batch_size); -} - -// hgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const Half* alpha, - const Half* const A[], int ldA, - const Half* const B[], int ldB, - const Half* beta, - Half* const C[], int ldC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasHgemmBatched"); - - return cublasHgemmBatched(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(const_cast(A)), ldA, - // A, ldA, // cuBLAS 9.2 - reinterpret_cast(const_cast(B)), ldB, - // B, ldB, // cuBLAS 9.2 - reinterpret_cast(beta), - reinterpret_cast<__half**>(const_cast(C)), ldC, - // C, ldC, // cuBLAS 9.2 - batch_size); -} - -// sgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const float* alpha, - const float* const A[], int ldA, - const float* const B[], int ldB, - const float* beta, - float* const C[], int ldC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasSgemmBatched"); - - return cublasSgemmBatched(handle, transA, transB, - m, n, k, - alpha, - const_cast(A), ldA, - // A, ldA, // cuBLAS 9.2 - const_cast(B), ldB, - // B, ldB, // cuBLAS 9.2 - beta, - const_cast(C), ldC, - // C, ldC, // cuBLAS 9.2 - batch_size); -} - -// dgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const double* alpha, - const double* const A[], int ldA, - const double* const B[], int ldB, - const double* beta, - double* const C[], int ldC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasDgemmBatched"); - - return cublasDgemmBatched(handle, transA, transB, - m, n, k, - alpha, - const_cast(A), ldA, - // A, ldA, // cuBLAS 9.2 - const_cast(B), ldB, - // B, ldB, // cuBLAS 9.2 - beta, - const_cast(C), ldC, - // C, ldC, // cuBLAS 9.2 - batch_size); -} - -// cgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const ComplexFloat* alpha, - const ComplexFloat* const A[], int ldA, - const ComplexFloat* const B[], int ldB, - const ComplexFloat* beta, - ComplexFloat* const C[], int ldC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasCgemmBatched"); - - return cublasCgemmBatched(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - const_cast(reinterpret_cast(A)), ldA, - //reinterpret_cast(A), ldA, // cuBLAS 9.2 - const_cast(reinterpret_cast(B)), ldB, - //reinterpret_cast(B), ldB, // cuBLAS 9.2 - reinterpret_cast(beta), - const_cast(reinterpret_cast(C)), ldC, - //reinterpret_cast(C), ldC, // cuBLAS 9.2 - batch_size); -} - -// zgemm -inline cublasStatus_t -gemm_batch(cublasHandle_t handle, - cublasOperation_t transA, cublasOperation_t transB, - int m, int n, int k, - const ComplexDouble* alpha, - const ComplexDouble* const A[], int ldA, - const ComplexDouble* const B[], int ldB, - const ComplexDouble* beta, - ComplexDouble* const C[], int ldC, - int batch_size) -{ - BLAM_DEBUG_OUT("cublasZgemmBatched"); - - return cublasZgemmBatched(handle, transA, transB, - m, n, k, - reinterpret_cast(alpha), - const_cast(reinterpret_cast(A)), ldA, - //reinterpret_cast(A), ldA, // cuBLAS 9.2 - const_cast(reinterpret_cast(B)), ldB, - //reinterpret_cast(B), ldB, // cuBLAS 9.2 - reinterpret_cast(beta), - const_cast(reinterpret_cast(C)), ldC, - //reinterpret_cast(C), ldC, // cuBLAS 9.2 - batch_size); -} - -} // end namespace cublas -} // end namespace blam diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h deleted file mode 100644 index 88481a82e0e08f06b54c07c946d28160d41f9f07..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h +++ /dev/null @@ -1,143 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Contains code for debugging cutlass code -*/ - -#pragma once - -#include "device_dump.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/****************************************************************************** - * Debug and logging macros - ******************************************************************************/ - -/** - * Formats and prints the given message to stdout - */ -#if !defined(CUDA_LOG) -#if !defined(__CUDA_ARCH__) -#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) -#else -#define CUDA_LOG(format, ...) \ - printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ - blockIdx.x, \ - blockIdx.y, \ - blockIdx.z, \ - threadIdx.x, \ - threadIdx.y, \ - threadIdx.z, \ - __VA_ARGS__); -#endif -#endif - -/** - * Formats and prints the given message to stdout only if DEBUG is defined - */ -#if !defined(CUDA_LOG_DEBUG) -#ifdef DEBUG -#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) -#else -#define CUDA_LOG_DEBUG(format, ...) -#endif -#endif - -/** - * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) - * along with the supplied source context. - * - * \return The CUDA error. - */ -__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, - const char* expression, - const char* filename, - int line) { - (void)filename; - (void)line; - if (error) { -#if !defined(__CUDA_ARCH__) - fprintf( - stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); - fflush(stderr); -#else - printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); -#endif - } - return error; -} - -/** - * \brief Perror macro - */ -#ifndef CUDA_PERROR -#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) -#endif - -/** - * \brief Perror macro with exit - */ -#ifndef CUDA_PERROR_EXIT -#define CUDA_PERROR_EXIT(e) \ - do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ - exit(1); \ - } } while (0) -#endif - -/** - * \brief Perror macro only if DEBUG is defined - */ -#ifndef CUDA_PERROR_DEBUG -#ifdef DEBUG -#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) -#else -#define CUDA_PERROR_DEBUG(e) (e) -#endif -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// A small helper class to dump a type at compile time -// Usage:: DumpType::Class -template -struct DebugType {}; - -template -void DebugTypeFunc(T const& t) { - T::t; -} - -// A small helper class to dump a compile time constant at compile time -// Usage: DumpValue::kConstant -template -struct DebugValue {}; diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h deleted file mode 100644 index a73a8cfe79dd22c2d298fcb3be8cf25d5e3f5734..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h +++ /dev/null @@ -1,187 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include "cutlass/cutlass.h" - -/** - * \file - * \brief C++ interface to dump fragments and shared memory contents for - * debugging. - */ - -namespace cutlass { -namespace debug { - -/****************************************************************************** - * Dump the fragments - ******************************************************************************/ - -/// The first N threads dump the first M elements from their fragments with a -/// stride of S elements. If N is not specified, dump the data of all the -/// threads. If M is not specified, dump all the elements of the fragment. -template -CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, - int S = 1) { - int total_threads = blockDim.x * blockDim.y * blockDim.z; - int block_id = - blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; - int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + - (threadIdx.y * blockDim.x) + threadIdx.x; - - if (N < 0 || N > total_threads) { - if (thread_id == 0 && block_id == 0) - printf("Thread number N = %d should between [1, %d].\n", N, - total_threads); - - __syncthreads(); - - return; - } - - int total_elements = int(frag.size()); - - if (M < 0 || M > total_elements) { - if (thread_id == 0 && block_id == 0) - printf("Element number M = %d should between [1, %d].\n", M, - total_elements); - - __syncthreads(); - - return; - } - - if (N == 0) N = total_threads; - - if (M == 0) M = total_elements; - - if (S < 1 || S > M) { - if (thread_id == 0 && block_id == 0) - printf("Stride S = %d should between [1, %d].\n", S, M); - - __syncthreads(); - - return; - } - - if (thread_id == 0 && block_id == 0) - printf("\n*******************Dumping the fragments*******************\n\n"); - - CUTLASS_PRAGMA_NO_UNROLL - for (int tid = 0; tid < N; ++tid) { - if (tid == thread_id) { - printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); - CUTLASS_PRAGMA_NO_UNROLL - for (int i = 0; i < M; i += S) { - printf("%.0f ", float(typename Fragment::value_type(frag[i]))); - } - printf("\n"); - } - - __syncthreads(); - } - - if (thread_id == 0 && block_id == 0) - printf("\n***********************************************************\n\n"); - - __syncthreads(); - - return; -} - -/****************************************************************************** - * Dump the shared memory - ******************************************************************************/ - -#define SHMEM_ROW_SIZE 128 - -/// Dump the shared memory contents. ptr is the begin address, size specifies -/// the number of elements that need to be dumped, and S specifies the stride. -template -CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { - int block_id = - blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; - int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + - (threadIdx.y * blockDim.x) + threadIdx.x; - - if (ptr == nullptr) { - if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); - - __syncthreads(); - return; - } - - if (size < 1) { - if (thread_id == 0 && block_id == 0) - printf("Element size is less than 1\n"); - - __syncthreads(); - - return; - } - - int row_elements = SHMEM_ROW_SIZE / sizeof(Element); - - if (S < 1 || S > row_elements) { - if (thread_id == 0 && block_id == 0) - printf("Stride S = %d should between [1, %d].\n", S, row_elements); - - __syncthreads(); - - return; - } - - __syncthreads(); - - if (thread_id == 0) - printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); - - if (thread_id == 0) { - for (int i = 0; i < size; i += row_elements) { - for (int j = 0; j < row_elements; j += S) { - printf("%.0f ", float(ptr[i + j])); - } - - printf("\n"); - } - } - - if (thread_id == 0) - printf("\n***********************************************************\n\n"); - - __syncthreads(); - - return; -} -} // namespace debug -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h deleted file mode 100644 index 59457b2e8122f46e443844fe276b2c7fb35f3f56..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h +++ /dev/null @@ -1,402 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. - */ - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" -#include "device_utils.h" -#include - -namespace cutlass { - -/** \brief interface to do group norm on a device memory tensor with NHWC layout. - * \tparam T: data type - */ -template -void groupnorm(cutlass::Tensor4DCoord input_size, - const int num_groups, - const float eps, - TensorRef ref_output, - TensorRef ref_input, - TensorRef ref_gamma, - TensorRef ref_beta, - cudaStream_t stream); - -extern __shared__ char groupnorm_shm[]; - -// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, -// we store the input in the shared memory. -// grid(num_groups, dim0) -// block(BLOCKSIZE) -// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group -template -__global__ void groupnorm_twopass_store_locally(T* output, - const T* input, - const T* gamma, - const T* beta, - int num_groups, - int prod_dim1_to_last_dim, - int last_dim, - const float eps, - const int TVecs_PER_THREAD) -{ - const int bid = blockIdx.y; // index of batch - const int gid = blockIdx.x; // index of group - const int tid = threadIdx.x; // index of thread - const int bdimx = blockDim.x; - const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; - const int v_reduce_elements = s_reduce_elements / T_PER_TVec; - const int s_group_stride = last_dim / num_groups; - const int v_group_stride = s_group_stride / T_PER_TVec; - const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; - const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; - TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; - T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; - float local_sum[1] = {0.0f}; - -// load from global memory into shared memory -#pragma unroll - for (int i = 0; i < TVecs_PER_THREAD; i += 1) { - const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; - const int offset_in_group = - ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) - / T_PER_TVec; - if (current_load_start_idx < s_reduce_elements) { - TVec tmp_vec = input_TVec_ptr[offset_in_group]; - T* tmp_vec_ptr = (T*)(&tmp_vec); - const int local_val_offset = i * T_PER_TVec; -#pragma unroll - for (int j = 0; j < T_PER_TVec; j++) { - float tmp = static_cast(tmp_vec_ptr[j]); - local_sum[0] += tmp; - local_val[local_val_offset + j] = tmp_vec_ptr[j]; - } - } - } - __shared__ float s_mean, s_variance; - - // reduction for mean - if (bdimx <= 32) { - warpReduceSum(local_sum); - } - else { - blockReduceSum(local_sum); - } - if (tid == 0) { - s_mean = local_sum[0] / s_reduce_elements; - } - __syncthreads(); - - // reduction for std - local_sum[0] = 0.0f; -#pragma unroll - for (int i = 0; i < TVecs_PER_THREAD; i += 1) { - const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; - if (current_load_start_idx < s_reduce_elements) { - const int local_val_offset = i * T_PER_TVec; -#pragma unroll - for (int j = 0; j < T_PER_TVec; j++) { - float tmp = static_cast(local_val[local_val_offset + j]); - tmp -= s_mean; - local_sum[0] += tmp * tmp; - } - } - } - if (bdimx <= 32) { - warpReduceSum(local_sum); - } - else { - blockReduceSum(local_sum); - } - if (tid == 0) { - s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); - } - __syncthreads(); - - // normalize - const int gamma_offset_of_group = gid * v_group_stride; - const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; - const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; -#pragma unroll - for (int i = 0; i < TVecs_PER_THREAD; i += 1) { - const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; - const int offset_in_group = - ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) - / T_PER_TVec; - const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; - const int local_val_offset = i * T_PER_TVec; - if (current_load_start_idx < s_reduce_elements) { - TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; - TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; - T* gamma_val_ptr = (T*)(&gamma_val); - T* beta_val_ptr = (T*)(&beta_val); - TVec tmp_vec; - T* tmp_vec_ptr = (T*)(&tmp_vec); -#pragma unroll - for (int j = 0; j < T_PER_TVec; j++) { - float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance - * static_cast(gamma_val_ptr[j]) - + static_cast(beta_val_ptr[j]); - if (sizeof(T) == sizeof(half)) { - tmp_vec_ptr[j] = T(__float2half_rn(tmp)); - } - else { - tmp_vec_ptr[j] = T(tmp); - } - } - output_TVec_ptr[offset_in_group] = tmp_vec; - } - } -} - -// For large prod_dim1_to_last_dim/num_groups, -// in which the data cannot be stored locally, -// we will load from global memory multiple times, -// grid(num_groups, dim0) -// block(BLOCKSIZE) -// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group -template -__global__ void groupnorm_twopass_multiple_load(T* output, - const T* input, - const T* gamma, - const T* beta, - int num_groups, - int prod_dim1_to_last_dim, - int last_dim, - const float eps, - const int TVecs_PER_THREAD) -{ - const int bid = blockIdx.y; // index of batch - const int gid = blockIdx.x; // index of group - const int tid = threadIdx.x; // index of thread - const int bdimx = blockDim.x; - const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; - const int v_reduce_elements = s_reduce_elements / T_PER_TVec; - const int s_group_stride = last_dim / num_groups; - const int v_group_stride = s_group_stride / T_PER_TVec; - const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; - const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; - TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; - float local_sum[1] = {0.0f}; - -#pragma unroll - for (int i = 0; i < TVecs_PER_THREAD; i += 1) { - const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; - if (current_load_start_idx < s_reduce_elements) { - const int offset_in_group = - ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) - / T_PER_TVec; - TVec tmp_vec = input_TVec_ptr[offset_in_group]; - T* tmp_vec_ptr = (T*)(&tmp_vec); -#pragma unroll - for (int j = 0; j < T_PER_TVec; j++) { - float tmp = static_cast(tmp_vec_ptr[j]); - local_sum[0] += tmp; - } - } - } - __shared__ float s_mean, s_variance; - - // reduction for mean - if (bdimx <= 32) { - warpReduceSum(local_sum); - } - else { - blockReduceSum(local_sum); - } - if (tid == 0) { - s_mean = local_sum[0] / s_reduce_elements; - } - __syncthreads(); - - // reduction for std - local_sum[0] = 0.0f; -#pragma unroll - for (int i = 0; i < TVecs_PER_THREAD; i += 1) { - const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; - if (current_load_start_idx < s_reduce_elements) { - const int offset_in_group = - ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) - / T_PER_TVec; - TVec tmp_vec = input_TVec_ptr[offset_in_group]; - T* tmp_vec_ptr = (T*)(&tmp_vec); -#pragma unroll - for (int j = 0; j < T_PER_TVec; j++) { - float tmp = static_cast(tmp_vec_ptr[j]); - tmp -= s_mean; - local_sum[0] += tmp * tmp; - } - } - } - if (bdimx <= 32) { - warpReduceSum(local_sum); - } - else { - blockReduceSum(local_sum); - } - if (tid == 0) { - s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); - } - __syncthreads(); - - // normalize - const int gamma_offset_of_group = gid * v_group_stride; - const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; - const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; -#pragma unroll - for (int i = 0; i < TVecs_PER_THREAD; i += 1) { - const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; - if (current_load_start_idx < s_reduce_elements) { - const int offset_in_group = - ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) - / T_PER_TVec; - const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; - TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; - TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; - T* gamma_val_ptr = (T*)(&gamma_val); - T* beta_val_ptr = (T*)(&beta_val); - TVec tmp_vec = input_TVec_ptr[offset_in_group]; - T* tmp_vec_ptr = (T*)(&tmp_vec); - TVec output_tmp_vec; - T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); -#pragma unroll - for (int j = 0; j < T_PER_TVec; j++) { - float tmp = - (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) - + static_cast(beta_val_ptr[j]); - if (sizeof(T) == sizeof(half)) { - output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); - } - else { - output_tmp_vec_ptr[j] = T(tmp); - } - } - output_TVec_ptr[offset_in_group] = output_tmp_vec; - } - } -} - -//ref_input & ref_output should be [N, H, W, C] -//ref_gamma & ref_beta should be [1, 1, 1, C] -template -void groupnorm(cutlass::Tensor4DCoord input_size, - const int num_groups, - const float eps, - TensorRef ref_output, - TensorRef ref_input, - TensorRef ref_gamma, - TensorRef ref_beta, - cudaStream_t stream){ - const int N = input_size.n(); - const int H = input_size.h(); - const int W = input_size.w(); - const int C = input_size.c(); - if (C % num_groups != 0){ - printf("[ERROR] C should be a multiple of num_groups.\n"); - } - T* output = ref_output.data(); - const T* input = ref_input.data(); - const T* gamma = ref_gamma.data(); - const T* beta = ref_beta.data(); - - const int dim0 = N; - const int last_dim = C; - const int prod_dim1_to_last_dim = H*W*C; - const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; - const int s_group_stride = last_dim / num_groups; - dim3 grid(num_groups, dim0); - int threadblock_size = 32; - if (s_group_stride % 2 == 0) { - const int T_PER_TVec = 2; - while (threadblock_size < 1024) { - if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) - break; - threadblock_size *= 2; - } - dim3 block(threadblock_size); - const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; - const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); - // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; - // the size of grid & block may have better choice for different cases. - // ensure shared memory is smaller than 48KB - if (std::is_same::value){ - if (shm_size < 48 * 1024) { - groupnorm_twopass_store_locally<<>>( - output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); - } - else { - groupnorm_twopass_multiple_load<<>>( - output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); - } - } - else{ - if (shm_size < 48 * 1024) { - groupnorm_twopass_store_locally<<>>( - output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); - } - else { - groupnorm_twopass_multiple_load<<>>( - output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); - } - } - } - else { - const int T_PER_TVec = 1; - while (threadblock_size < 1024) { - if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) - break; - threadblock_size *= 2; - } - dim3 block(threadblock_size); - const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; - const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); - if (shm_size < 48 * 1024) { - groupnorm_twopass_store_locally<<>>( - output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); - } - else { - groupnorm_twopass_multiple_load<<>>( - output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); - } - } - -} - -} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h deleted file mode 100644 index 0fcbf5cb0f4bf3152a708c6e3845e89fd214cfac..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h +++ /dev/null @@ -1,644 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. - */ - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" -#include "device_utils.h" -#include - -namespace cutlass { - -/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. - * \tparam T: data type - */ -template -void layernorm(cutlass::MatrixCoord tensor_size, - TensorRef ref_output, - TensorRef ref_input, - TensorRef ref_gamma, - TensorRef ref_beta, - cudaStream_t stream); - -/** - * output [m, n] row-major - * input [m, n] row-major - * gamma [n] - * beta [n] - * grid(m) - * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements -*/ -template -__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, - const T* input, - const T* gamma, - const T* beta, - const int m, - const int n) -{ - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean, s_variance; - T local_val[ITEM_PER_THREAD]; - float local_sums[1] = {0.0f}; - int offset = m_idx * n; - input += offset; - output += offset; - - const T zero = T(0.0f); - #pragma unroll - for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ - int index = tid + i*bdimx; - local_val[i] = index < n ? input[index] : zero; - local_sums[0] += static_cast(local_val[i]); - } - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = local_sums[0] / n; - } - __syncthreads(); - - local_sums[0] = 0.0f; - #pragma unroll - for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ - int index = tid + i*bdimx; - if (index < n){ - const float tmp = static_cast(local_val[i]) - s_mean; - local_sums[0] += tmp * tmp; - } - } - - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_variance = rsqrtf(local_sums[0] / n + 1e-5); - } - __syncthreads(); - - #pragma unroll - for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ - int index = tid + i*bdimx; - if (index < n) { - const T gamma_val = gamma[index]; - const T beta_val = beta[index]; - output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); - } - } -} - -/** - * output [m, n] row-major - * input [m, n] row-major - * gamma [n] - * beta [n] - * grid(m) - * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; -*/ -template -__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, - const T2* input, - const T2* gamma, - const T2* beta, - const int m, - const int n) -{ - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean, s_variance; - float local_sums[1] = {0.0f}; - T2 local_val[ITEM_PER_THREAD]; - const int n_2 = n / 2; - int offset = m_idx * n_2; - input += offset; - output += offset; - - const T2 zero = {T(0.0f), T(0.0f)}; - #pragma UNROLL - for (int i = 0; i < ITEM_PER_THREAD; i += 1) { - const int index = i*bdimx + tid; - local_val[i] = index < n_2 ? input[index] : zero; - local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); - } - - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = local_sums[0] / n; - } - __syncthreads(); - - local_sums[0] = 0.0f; - #pragma UNROLL - for (int i = 0; i < ITEM_PER_THREAD; i += 1) { - const int index = i*bdimx + tid; - if (index < n_2){ - const float2 tmp = {static_cast(local_val[i].x) - s_mean, - static_cast(local_val[i].y) - s_mean}; - local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; - } - } - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_variance = rsqrtf(local_sums[0] / n + 1e-5); - } - __syncthreads(); - - #pragma UNROLL - for (int i = 0; i < ITEM_PER_THREAD; i += 1) { - const int index = i*bdimx + tid; - if (index < n_2){ - const T2 gamma_val = gamma[index]; - const T2 beta_val = beta[index]; - T2 tmp; - tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); - tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); - output[index] = tmp; - } - } -} - -/** - * output [m, n] row-major - * input [m, n] row-major - * gamma [n] - * beta [n] - * grid(m) - * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; -*/ -template -__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, - const T4* input, - const T4* gamma, - const T4* beta, - const int m, - const int n) -{ - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean, s_variance; - float local_sums[1] = {0.0f}; - T4 local_val[ITEM_PER_THREAD]; - const int n_4 = n / 4; - int offset = m_idx * n_4; - input += offset; - output += offset; - - const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; - #pragma UNROLL - for (int i = 0; i < ITEM_PER_THREAD; i += 1) { - const int index = i*bdimx + tid; - local_val[i] = index < n_4 ? input[index] : zero; - local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + - static_cast(local_val[i].z) + static_cast(local_val[i].w); - } - - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = local_sums[0] / n; - } - __syncthreads(); - - local_sums[0] = 0.0f; - #pragma UNROLL - for (int i = 0; i < ITEM_PER_THREAD; i += 1) { - const int index = i*bdimx + tid; - if (index < n_4){ - const float4 tmp = {static_cast(local_val[i].x) - s_mean, - static_cast(local_val[i].y) - s_mean, - static_cast(local_val[i].z) - s_mean, - static_cast(local_val[i].w) - s_mean}; - local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; - } - } - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_variance = rsqrtf(local_sums[0] / n + 1e-5); - } - __syncthreads(); - - #pragma UNROLL - for (int i = 0; i < ITEM_PER_THREAD; i += 1) { - const int index = i*bdimx + tid; - if (index < n_4){ - const T4 gamma_val = gamma[index]; - const T4 beta_val = beta[index]; - T4 tmp; - tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); - tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); - tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); - tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); - output[index] = tmp; - } - } -} - -/** - * output [m, n] row-major - * input [m, n] row-major - * gamma [n] - * beta [n] - * grid(m) - * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements -*/ -template -__global__ void layernorm_twoPassAlgo_e1(T* output, - const T* input, - const T* gamma, - const T* beta, - const int m, - const int n) -{ - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean, s_variance; - float local_sums[1] = {0.0f}; - int offset = m_idx * n; - input += offset; - output += offset; - - for (int index = tid ; index < n ; index += bdimx){ - float local_val = static_cast(input[index]); - local_sums[0] += local_val; - } - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = local_sums[0] / n; - } - __syncthreads(); - - local_sums[0] = 0.0f; - for (int index = tid ; index < n ; index += bdimx){ - float local_val = static_cast(input[index]); - local_val = local_val - s_mean; - local_sums[0] += local_val * local_val; - } - - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_variance = rsqrtf(local_sums[0] / n + 1e-5); - } - __syncthreads(); - - for (int index = tid ; index < n ; index += bdimx){ - const T gamma_val = gamma[index]; - const T beta_val = beta[index]; - const T local_val = input[index]; - output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); - } -} - -/** - * output [m, n] row-major - * input [m, n] row-major - * gamma [n] - * beta [n] - * grid(m) - * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; -*/ -template -__global__ void layernorm_twoPassAlgo_e2(T2* output, - const T2* input, - const T2* gamma, - const T2* beta, - const int m, - const int n) -{ - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean, s_variance; - float local_sums[1] = {0.0f}; - const int n_2 = n / 2; - int offset = m_idx * n_2; - input += offset; - output += offset; - - for (int index = tid; index < n_2; index += bdimx) { - const T2 local_val = input[index]; - local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); - } - - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = local_sums[0] / n; - } - __syncthreads(); - - local_sums[0] = 0.0f; - for (int index = tid; index < n_2; index += bdimx) { - const T2 local_val = input[index]; - const float2 tmp = {static_cast(local_val.x) - s_mean, - static_cast(local_val.y) - s_mean}; - local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; - } - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_variance = rsqrtf(local_sums[0] / n + 1e-5); - } - __syncthreads(); - - for (int index = tid; index < n_2; index += bdimx) { - const T2 local_val = input[index]; - const T2 gamma_val = gamma[index]; - const T2 beta_val = beta[index]; - T2 tmp; - tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); - tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); - output[index] = tmp; - } -} - -template -void layernorm(cutlass::MatrixCoord tensor_size, - TensorRef ref_output, - TensorRef ref_input, - TensorRef ref_gamma, - TensorRef ref_beta, - cudaStream_t stream){ - const int m = tensor_size.row(); - const int n = tensor_size.column(); - T* output = ref_output.data(); - const T* input = ref_input.data(); - const T* gamma = ref_gamma.data(); - const T* beta = ref_beta.data(); - dim3 grid(m); - dim3 block((n + 31)/32*32); - if (block.x > 1024){ - block.x = 1024; - } - // TODO : There should be better configs for different cases, we only use several samples to show how to use here - // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. - if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { - block.x = (n/4 + 31)/32*32; - if (std::is_same::value) { - layernorm_twoPassAlgo_stored_locally_e4<<>>( - (float4*)output, - (const float4*)input, - (const float4*)gamma, - (const float4*)beta, - m, - n); - } // if (std::is_same::value) - else { - layernorm_twoPassAlgo_stored_locally_e4<<>>( - (half4*)output, - (const half4*)input, - (const half4*)gamma, - (const half4*)beta, - m, - n); - } - } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) - else if (n % 2 == 0) { - if (n / 2 <= 1024) { - block.x = (n/2 + 31)/32*32; - if (std::is_same::value) { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (float2*)output, - (const float2*)input, - (const float2*)gamma, - (const float2*)beta, - m, - n); - } //if (std::is_same::value) - else { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (half2*)output, - (const half2*)input, - (const half2*)gamma, - (const half2*)beta, - m, - n); - } - } // if (n / 2 <= 1024) - else if (n <= 8192) { - block.x = ((n + 7)/8 + 31)/32*32; - if (std::is_same::value) { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (float2*)output, - (const float2*)input, - (const float2*)gamma, - (const float2*)beta, - m, - n); - } // if (std::is_same::value) - else { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (half2*)output, - (const half2*)input, - (const half2*)gamma, - (const half2*)beta, - m, - n); - } - } // if (n <= 8192) - else if (n <= 16384) { - block.x = ((n + 15)/ 16 + 31)/32*32; - if (std::is_same::value) { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (float2*)output, - (const float2*)input, - (const float2*)gamma, - (const float2*)beta, - m, - n); - } // if (std::is_same::value) - else { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (half2*)output, - (const half2*)input, - (const half2*)gamma, - (const half2*)beta, - m, - n); - } - } // if (n <= 16384) - else if (n <= 32768) { - block.x = ((n + 31)/32 + 31)/32*32; - if (std::is_same::value) { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (float2*)output, - (const float2*)input, - (const float2*)gamma, - (const float2*)beta, - m, - n); - } // if (std::is_same::value) - else { - layernorm_twoPassAlgo_stored_locally_e2<<>>( - (half2*)output, - (const half2*)input, - (const half2*)gamma, - (const half2*)beta, - m, - n); - } - } // if (n <= 32768) - else { - if (block.x > 512) - block.x = 512; - if (std::is_same::value) { - layernorm_twoPassAlgo_e2<<>>( - (float2 *)output, - (const float2 *)input, - (const float2 *)gamma, - (const float2 *)beta, - m, - n); - } // if (std::is_same::value) - else { - layernorm_twoPassAlgo_e2<<>>( - (half2 *)output, - (const half2 *)input, - (const half2 *)gamma, - (const half2 *)beta, - m, - n); - } - } - } // if (n % 2 == 0) - else { - if (n <= 1024) { - layernorm_twoPassAlgo_stored_locally_e1<<>>( - output, - input, - gamma, - beta, - m, - n); - } // if (n <= 1024) - else if (n <= 8192) { - block.x = ((n + 7)/8 + 31)/32*32; - layernorm_twoPassAlgo_stored_locally_e1<<>>( - output, - input, - gamma, - beta, - m, - n); - } // if (n <= 8192) - else if (n <= 16384) { - block.x = ((n + 15)/16 + 32)/32*32; - layernorm_twoPassAlgo_stored_locally_e1<<>>( - output, - input, - gamma, - beta, - m, - n); - } // if (n <= 16384) - else if (n <= 32768) { - block.x = ((n + 31)/32 + 31)/32*32; - layernorm_twoPassAlgo_stored_locally_e1<<>>( - output, - input, - gamma, - beta, - m, - n); - } // if (n <= 32768) - else{ - if (block.x > 512) { - block.x = 512; - } - layernorm_twoPassAlgo_e1<<>>( - output, - input, - gamma, - beta, - m, - n); - } - } -} - -} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h deleted file mode 100644 index 44f6a467a5d0938289e4bc127cddc13b9aeabdf3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h +++ /dev/null @@ -1,375 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief C++ interface to CUDA device memory management functions. - */ - -#include -#include - -#include "cutlass/platform/platform.h" -#include "cutlass/numeric_types.h" -#include "cutlass/trace.h" -#include "exceptions.h" - -namespace cutlass { -namespace device_memory { - -/****************************************************************************** - * Allocation lifetime - ******************************************************************************/ - -/// Allocate a buffer of \p count elements of type \p T on the current CUDA device -template -T* allocate(size_t count = 1) { - - T* ptr = 0; - size_t bytes = count * sizeof_bits::value / 8; - - cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); - - if (cuda_error != cudaSuccess) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 0) - std::ostringstream os; - os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes; - CUTLASS_TRACE_HOST(os.str()); -#endif - throw cuda_exception("Failed to allocate memory", cuda_error); - } -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - else { - std::ostringstream os; - os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes; - CUTLASS_TRACE_HOST(os.str()); - } -#endif - - return ptr; -} - -/// Free the buffer pointed to by \p ptr -template -void free(T* ptr) { - if (ptr) { - cudaError_t cuda_error = (cudaFree(ptr)); - if (cuda_error != cudaSuccess) { - throw cuda_exception("Failed to free device memory", cuda_error); - } - } -} - -/****************************************************************************** - * Data movement - ******************************************************************************/ - -template -void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { - size_t bytes = count * sizeof_bits::value / 8; - if (bytes == 0 && count > 0) { - bytes = 1; - } - cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); - if (cuda_error != cudaSuccess) { - std::ostringstream os; - os << "cutlass::device_memory::copy: cudaMemcpy() failed: " - << "dst=" << dst << ", src=" << src - << ", bytes=" << bytes << ", count=" << count; - if (kind == cudaMemcpyHostToDevice) { - os << ", kind=cudaMemcpyHostToDevice"; - } - else if (kind == cudaMemcpyDeviceToHost) { - os << ", kind=cudaMemcpyDeviceToHost"; - } - else if (kind == cudaMemcpyDeviceToDevice) { - os << ", kind=cudaMemcpyDeviceToDevice"; - } - else if (kind == cudaMemcpyHostToHost) { - os << ", kind=cudaMemcpyHostToHost"; - } - else if (kind == cudaMemcpyDefault) { - os << ", kind=cudaMemcpyDefault"; - } - else { - os << ", kind=Unknown"; - } - os << ", error: " << cudaGetErrorString(cuda_error); - - throw cuda_exception(os.str().c_str(), cuda_error); - } -} - -template -void copy_to_device(T* dst, T const* src, size_t count = 1) { - copy(dst, src, count, cudaMemcpyHostToDevice); -} - -template -void copy_to_host(T* dst, T const* src, size_t count = 1) { - copy(dst, src, count, cudaMemcpyDeviceToHost); -} - -template -void copy_device_to_device(T* dst, T const* src, size_t count = 1) { - copy(dst, src, count, cudaMemcpyDeviceToDevice); -} - -template -void copy_host_to_host(T* dst, T const* src, size_t count = 1) { - copy(dst, src, count, cudaMemcpyHostToHost); -} - -/// Copies elements from device memory to host-side range -template -void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { - size_t elements = end - begin; - copy_to_host(&*begin, device_begin, elements); -} - -/// Copies elements to device memory from host-side range -template -void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { - size_t elements = end - begin; - copy_to_device(device_begin, &*begin, elements); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device_memory - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class DeviceAllocation { -public: - - /// Delete functor for CUDA device memory - struct deleter { - void operator()(T* ptr) { - cudaError_t cuda_error = (cudaFree(ptr)); - if (cuda_error != cudaSuccess) { - // noexcept - // throw cuda_exception("cudaFree() failed", cuda_error); - return; - } - } - }; - -public: - // - // Data members - // - - /// Number of elements of T allocated on the current CUDA device - size_t capacity; - - /// Smart pointer - platform::unique_ptr smart_ptr; - -public: - - // - // Static methods - // - - /// Static member to compute the number of bytes needed for a given number of elements - static size_t bytes(size_t elements) { - if (sizeof_bits::value < 8) { - size_t const kElementsPerByte = 8 / sizeof_bits::value; - return elements / kElementsPerByte; - } - else { - size_t const kBytesPerElement = sizeof_bits::value / 8; - return elements * kBytesPerElement; - } - } - -public: - - // - // Methods - // - - /// Constructor: allocates no memory - DeviceAllocation() : capacity(0) {} - - /// Constructor: allocates \p capacity elements on the current CUDA device - DeviceAllocation(size_t _capacity) : - smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} - - /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation - DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} - - /// Copy constructor - DeviceAllocation(DeviceAllocation const &p): - smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { - - device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); - } - - /// Move constructor - DeviceAllocation(DeviceAllocation &&p): capacity(0) { - std::swap(smart_ptr, p.smart_ptr); - std::swap(capacity, p.capacity); - } - - /// Destructor - ~DeviceAllocation() { reset(); } - - /// Returns a pointer to the managed object - T* get() const { return smart_ptr.get(); } - - /// Releases the ownership of the managed object (without deleting) and resets capacity to zero - T* release() { - capacity = 0; - return smart_ptr.release(); - } - - /// Deletes the managed object and resets capacity to zero - void reset() { - capacity = 0; - smart_ptr.reset(); - } - - /// Deletes managed object, if owned, and allocates a new object - void reset(size_t _capacity) { - reset(device_memory::allocate(_capacity), _capacity); - } - - /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity - void reset(T* _ptr, size_t _capacity) { - smart_ptr.reset(_ptr); - capacity = _capacity; - } - - /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. - void reallocate(size_t new_capacity) { - - platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); - - device_memory::copy_device_to_device( - new_allocation.get(), - smart_ptr.get(), - std::min(new_capacity, capacity)); - - std::swap(smart_ptr, new_allocation); - std::swap(new_capacity, capacity); - } - - /// Returns the number of elements - size_t size() const { - return capacity; - } - - /// Returns the number of bytes needed to store the allocation - size_t bytes() const { - return bytes(capacity); - } - - /// Returns a pointer to the object owned by *this - T* operator->() const { return smart_ptr.get(); } - - /// Returns the deleter object which would be used for destruction of the managed object. - deleter& get_deleter() { return smart_ptr.get_deleter(); } - - /// Returns the deleter object which would be used for destruction of the managed object (const) - const deleter& get_deleter() const { return smart_ptr.get_deleter(); } - - /// Copies a device-side memory allocation - DeviceAllocation & operator=(DeviceAllocation const &p) { - if (capacity != p.capacity) { - smart_ptr.reset(device_memory::allocate(p.capacity)); - capacity = p.capacity; - } - device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); - return *this; - } - - /// Move assignment - DeviceAllocation & operator=(DeviceAllocation && p) { - std::swap(smart_ptr, p.smart_ptr); - std::swap(capacity, p.capacity); - return *this; - } - - /// Copies the entire allocation from another location in device memory. - void copy_from_device(T const *ptr) const { - copy_from_device(ptr, capacity); - } - - /// Copies a given number of elements from device memory - void copy_from_device(T const *ptr, size_t elements) const { - device_memory::copy_device_to_device(get(), ptr, elements); - } - - void copy_to_device(T *ptr) const { - copy_to_device(ptr, capacity); - } - - void copy_to_device(T *ptr, size_t elements) const { - device_memory::copy_device_to_device(ptr, get(), elements); - } - - void copy_from_host(T const *ptr) const { - copy_from_host(ptr, capacity); - } - - void copy_from_host(T const *ptr, size_t elements) const { - device_memory::copy_to_device(get(), ptr, elements); - } - - void copy_to_host(T *ptr) const { - copy_to_host(ptr, capacity); - } - - void copy_to_host(T *ptr, size_t elements) const { - device_memory::copy_to_host(ptr, get(), elements); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace device_memory { - -/// Device allocation abstraction that tracks size and capacity -template -using allocation = cutlass::DeviceAllocation; - -} // namespace device_memory - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h deleted file mode 100644 index 8e38029951d27c0be8da059b59d2a83fe2762ef1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h +++ /dev/null @@ -1,141 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. - */ - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" - -namespace cutlass { - -/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. - * \tparam T: data type - */ -template -void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - TensorRef ref_input, - TensorRef ref_output, - cudaStream_t stream); - -template -__global__ void nchw_to_nhwc_kernel(T *output, - const T *input, - const int n, - const int h, - const int w, - const int c) { - const int hw = h*w; - const int chw = c*hw; - __shared__ T shbuf[32 * (32 + 1)]; - const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; - const int32_t wid = tid / 32; - const int32_t lid = tid % 32; - const int32_t ni = blockIdx.z; - const int32_t ci0 = blockIdx.y * 32; - const int32_t hwi0 = blockIdx.x * 32; - - const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; - const T *A = input + input_idx; - if (hwi0 + lid < hw) { - const int lid_x_33 = lid * 33; - if ((ci0 + 32) <= c) { - int ci = wid; // between 0 and 7 - CUTLASS_PRAGMA_UNROLL - for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { - shbuf[lid_x_33 + ci] = A[lid]; - A = &A[8 * hw]; - ci += 8; - } - } else { - for (int ci = wid; ci < 32; ci += 8) { - if ((ci + ci0) < c) { - shbuf[lid_x_33 + ci] = A[lid]; - } - A = &A[8 * hw]; - } - } - } - __syncthreads(); - - const int32_t ciOut = ci0 + lid; - output = &output[ni * chw + ciOut]; - if (ciOut < c) { - if (hwi0 + 32 < hw) { - int hwI = wid; - CUTLASS_PRAGMA_UNROLL - for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { - output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; - hwI += 8; - } - } else { - for (int hwI = wid; hwI < 32; hwI += 8) { - if (hwi0 + hwI < hw) { - output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; - } - } - } - } -} - -template -void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - TensorRef ref_input, - TensorRef ref_output, - cudaStream_t stream) { - - assert( - input_tensor_size.n() == output_tensor_size.n() && - input_tensor_size.c() == output_tensor_size.h() && - input_tensor_size.h() == output_tensor_size.w() && - input_tensor_size.w() == output_tensor_size.c()); - - int n = output_tensor_size.n(); - int h = output_tensor_size.h(); - int w = output_tensor_size.w(); - int c = output_tensor_size.c(); - - dim3 grid((h*w + 31)/32, (c + 31)/32, n); - dim3 block(32, 8); - nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), - n, h, w, c); -} - -} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h deleted file mode 100644 index f58da62a35350b4a865f4521ec1cbb76ae87e874..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h +++ /dev/null @@ -1,276 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief cuda kernels for padding in device memory with NHWC layout. - */ - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" - -namespace cutlass { - -/** \brief interface for padding in a device memory tensor with NHWC layout - * \tparam T: data type - */ -template -void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - TensorRef ref_input, - TensorRef ref_output, - cudaStream_t stream); - - -template -__global__ void nhwc_padding_kernel(const int32_t n, - const int32_t h, - const int32_t w, - const int32_t c_in, - const int32_t c_out, - const T zero, - const T *input, - T *output){ - - const int32_t idx_jump = blockDim.x * gridDim.x; - const int32_t total_elements = n * h * w * c_out; - - int32_t c_idx, w_idx, h_idx, n_idx, resudial; - - T value; - for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { - - c_idx = idx%c_out; - if (c_idx >= c_in){ - value = zero; - } - else{ - resudial = idx/c_out; - w_idx = resudial%w; - resudial = resudial/w; - h_idx = resudial%h; - n_idx = resudial/h; - resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; - value = input[resudial]; - } - output[idx] = value; - } -} - - -// fast kernel for c_in = 3 & c_out = 4 -template -__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, - const int32_t h, - const int32_t w, - const Tio *input, - Tio *output, - const int32_t max_output_element, - const int32_t max_input_element, - const Tio zero_io, - const Telement zero_element){ - __shared__ Tio shm[192]; - const int tidx = blockIdx.x * 192 + threadIdx.x; - const int threadidx = threadIdx.x; - - shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; - __syncthreads(); - - const int output_offset = blockIdx.x * 256; - const int lower_bound = max_output_element < output_offset + 256 ? max_output_element : output_offset + 256; - for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) - { - const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; - Telement array[element_in_Tio]; - CUTLASS_PRAGMA_UNROLL - for (int k = 0 ; k < element_in_Tio ; k++) - array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; - output[i] = *((const Tio *)array); - } -} - -// fast kernel for c_in = 3 & c_out = 8 -template -__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, - const int32_t h, - const int32_t w, - const Tio *input, - Tio *output, - const int32_t max_output_element, - const int32_t max_input_element, - const Tio zero_io, - const Telement zero_element){ - __shared__ Tio shm[192]; - const int tidx = blockIdx.x * 192 + threadIdx.x; - const int threadidx = threadIdx.x; - - shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; - __syncthreads(); - - const int output_offset = blockIdx.x * 512; - const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512; - for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) - { - const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; - Telement array[element_in_Tio]; - //float - if (element_in_Tio == 4){ - CUTLASS_PRAGMA_UNROLL - for (int k = 0 ; k < element_in_Tio ; k++) - array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]); - } - //half - else{ - CUTLASS_PRAGMA_UNROLL - for (int k = 0 ; k < element_in_Tio ; k++) - array[k] = (k >= 3) ? zero_element : shm_element[k]; - } - output[i] = *((const Tio *)array); - } -} - -template -void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - TensorRef ref_input, - TensorRef ref_output, - cudaStream_t stream){ - assert( - input_tensor_size.n() == output_tensor_size.n() && - input_tensor_size.h() == output_tensor_size.h() && - input_tensor_size.w() == output_tensor_size.w() && - input_tensor_size.c() <= output_tensor_size.c()); - - int n = input_tensor_size.n(); - int h = input_tensor_size.h(); - int w = input_tensor_size.w(); - int c_in = input_tensor_size.c(); - int c_out = output_tensor_size.c(); - - //case 1 : channel == 3 padding to 4 or 8 - if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ - dim3 block(192); - const int nhw = n*h*w; - const int nhwc = nhw*c_in; - //for half_t - if (cutlass::sizeof_bits::value == 16){ - const int element_in_Tio = 8; - const int max_input_element = nhwc/element_in_Tio; - const int max_output_element = nhw*c_out/element_in_Tio; - const int4 zero_io = {0, 0, 0, 0}; - const half_t zero_element = static_cast(0.0f); - dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); - if (c_out == 4){ - nhwc_padding_channel_3To4_kernel<<>> - (n, h, w, - (const int4 *)ref_input.data(), - (int4 *)ref_output.data(), - max_output_element, - max_input_element, - zero_io, - zero_element); - } - else if (c_out == 8){ - nhwc_padding_channel_3To8_kernel<<>> - (n, h, w, - (const int4 *)ref_input.data(), - (int4 *)ref_output.data(), - max_output_element, - max_input_element, - zero_io, - zero_element); - } - } - //for float - else{ - const int element_in_Tio = 4; - const int max_input_element = nhwc/element_in_Tio; - const int max_output_element = nhw*c_out/element_in_Tio; - const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; - const float zero_element = 0.0f; - dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); - if (c_out == 4){ - nhwc_padding_channel_3To4_kernel<<>> - (n, h, w, - (const float4 *)ref_input.data(), - (float4 *)ref_output.data(), - max_output_element, - max_input_element, - zero_io, - zero_element); - } - else if (c_out == 8){ - nhwc_padding_channel_3To8_kernel<<>> - (n, h, w, - (const float4 *)ref_input.data(), - (float4 *)ref_output.data(), - max_output_element, - max_input_element, - zero_io, - zero_element); - } - } - } - //case 2 : even channel - else if ((c_out % 2) == 0 && (c_in % 2) == 0){ - int32_t total_elements = n * h * w * c_out / 2; - int block_size = 256; - dim3 grid((total_elements + 255)/256); - dim3 block(block_size); - //for half_t - if (cutlass::sizeof_bits::value == 16){ - const __half2 zero = {0.0f, 0.0f}; - nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); - } - //for float - else{ - const float2 zero = {0.0f, 0.0f}; - nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); - } - } - //case 3 : odd channel - else{ - int32_t total_elements = n * h * w * c_out; - int block_size = 256; - dim3 grid((total_elements + 255)/256); - dim3 block(block_size); - const T zero = static_cast(0.0f); - nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data()); - } -} - - -} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h deleted file mode 100644 index 5633456c1412ff41366ec4c6ec5c3e6e3a2d6c19..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h +++ /dev/null @@ -1,573 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. - */ - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" -#include "device_utils.h" -#include - -namespace cutlass { - -/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. - * \tparam T: data type - */ -template -void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord filter_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - cutlass::MatrixCoord padding, - cutlass::MatrixCoord stride, - TensorRef ref_input, - TensorRef ref_output, - int poolingType, //0 for avg pooling ; 1 for max pooling - cudaStream_t stream); - -/** get the output size of pooling - */ -inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) -{ - return (H_W + 2 * padding - kernel_size) / stride + 1; -} - -/** - * input is [N, H, W, C] - * assume stride == kernel_size - * output_h = (H + 2*padding_H - kernel_H)/stride_H - * output_w = (W + 2*padding_W - kernel_W)/stride_W - * output is [N, output_h, output_w, C] - * grid(N, output_h, output_w) - * block(min(C, 256)) : - * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) -*/ -template -__global__ void pooling_nhwc_element1_kernel(T* output, - const T* input, - const int N, - const int H, - const int W, - const int C, - const int output_H, - const int output_W, - const int kernel_H, - const int kernel_W, - const int stride_H, - const int stride_W, - const int padding_H, - const int padding_W) -{ - const int tid = threadIdx.x; - const int n_idx = blockIdx.x; - const int output_h_idx = blockIdx.y; - const int output_w_idx = blockIdx.z; - - int h_start_idx = output_h_idx * stride_H - padding_H; - int h_end_idx = h_start_idx + kernel_H; - h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; - h_end_idx = h_end_idx > H ? H : h_end_idx; - - int w_start_idx = output_w_idx * stride_W - padding_W; - int w_end_idx = w_start_idx + kernel_W; - w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; - w_end_idx = w_end_idx > W ? W : w_end_idx; - - input += n_idx * H * W * C; - output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; - const int kernel_size2 = kernel_H * kernel_W; - for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { - float pooling; - if (IS_AVG_POOLING){ - pooling = 0.0f; - } - else{ - pooling = -FLT_MAX; - } - for (int h = h_start_idx; h < h_end_idx; h++) { - for (int w = w_start_idx; w < w_end_idx; w++) { - const int idx = (h * W + w) * C; - const float tmp = static_cast(input[idx + c_idx]); - if (IS_AVG_POOLING){ - pooling = pooling + tmp; - } - else{ - pooling = pooling > tmp ? pooling : tmp; - } - } - } - - T output_val; - if (IS_AVG_POOLING){ - output_val = T(pooling/kernel_size2); - } - else{ - output_val = T(pooling); - } - output[c_idx] = output_val; - } -} - -template -__global__ void pooling_nhwc_element2_kernel(T2* output, - const T2* input, - const int N, - const int H, - const int W, - const int C, - const int output_H, - const int output_W, - const int kernel_H, - const int kernel_W, - const int stride_H, - const int stride_W, - const int padding_H, - const int padding_W) -{ - const int tid = threadIdx.x; - const int n_idx = blockIdx.x; - const int output_h_idx = blockIdx.y; - const int output_w_idx = blockIdx.z; - - int h_start_idx = output_h_idx * stride_H - padding_H; - int h_end_idx = h_start_idx + kernel_H; - h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; - h_end_idx = h_end_idx > H ? H : h_end_idx; - - int w_start_idx = output_w_idx * stride_W - padding_W; - int w_end_idx = w_start_idx + kernel_W; - w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; - w_end_idx = w_end_idx > W ? W : w_end_idx; - - input += n_idx * H * W * C; - output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; - const int kernel_size2 = kernel_H * kernel_W; - for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { - float2 pooling; - if (IS_AVG_POOLING) { - pooling = {0.0f, 0.0f}; - } - else { - pooling = {-FLT_MAX, -FLT_MAX}; - } - for (int h = h_start_idx; h < h_end_idx; h++) { - for (int w = w_start_idx; w < w_end_idx; w++) { - const int idx = (h * W + w) * C; - const T2 tmp = input[idx + c_idx]; - const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; - if (IS_AVG_POOLING) { - pooling.x += tmp_flt2.x; - pooling.y += tmp_flt2.y; - } - else { - pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; - pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; - } - } - } - - T2 output_val; - if (IS_AVG_POOLING) { - output_val.x = T(pooling.x/kernel_size2); - output_val.y = T(pooling.y/kernel_size2); - } - else { - output_val.x = T(pooling.x); - output_val.y = T(pooling.y); - } - output[c_idx] = output_val; - } -} - -/** - * output [N, 1, 1, C] - * input [N, H, W, C] - * grid(C, N) - * block(block_size) -- each block deals with H*W/block_size elements; -*/ -template -__global__ void pooling_nxhTo1x1_element1_kernel( - T* output, const T* input, const int N, const int HW, const int C) -{ - const int c_idx = blockIdx.x; - const int n_idx = blockIdx.y; - float pooling[1]; - if (IS_AVG_POOLING) { - pooling[0] = 0.0f; - } - else { - pooling[0] = -FLT_MAX; - } - const size_t input_offset = n_idx * HW * C + c_idx; - input += input_offset; - const size_t output_offset = n_idx * C + c_idx; - output += output_offset; - int tid = threadIdx.x; - - for (int index = tid; index < HW; index += blockDim.x) { - float val = static_cast(input[index * C]); - if (IS_AVG_POOLING) { - pooling[0] += val; - } - else { - pooling[0] = pooling[0] > val ? pooling[0] : val; - } - } - if (blockDim.x <= 32) { - if (IS_AVG_POOLING) { - warpReduceSum(pooling); - } - else { - warpReduceMax(pooling); - } - } - else { - if (IS_AVG_POOLING) { - blockReduceSum(pooling); - } - else { - blockReduceMax(pooling); - } - } - __syncthreads(); - if (threadIdx.x == 0) { - T output_val; - if (IS_AVG_POOLING) { - output_val = T(pooling[0] / HW); - } - else { - output_val = T(pooling[0]); - } - output[0] = output_val; - } -} - - -/** - * output [N, 1, 1, C] - * input [N, H, W, C] - * grid(C/2, N) - * block(block_size) -- each thread deals with H*W/block_size * 2 elements; -*/ -template -__global__ void pooling_nxhTo1x1_element2_kernel( - T2* output, const T2* input, const int N, const int HW, const int C) -{ - const int c_idx = blockIdx.x; - const int n_idx = blockIdx.y; - float pooling[2]; - if (IS_AVG_POOLING) { - pooling[0] = pooling[1] = 0.0f; - } - else { - pooling[0] = pooling[1] = -FLT_MAX; - } - const int C_2 = C / 2; - const size_t input_offset = n_idx * HW * C_2 + c_idx; - input += input_offset; - const size_t output_offset = n_idx * C_2 + c_idx; - output += output_offset; - int tid = threadIdx.x; - - for (int index = tid; index < HW; index += blockDim.x) { - T2 val = input[index * C_2]; - float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; - if (IS_AVG_POOLING) { - pooling[0] += val_flt2.x; - pooling[1] += val_flt2.y; - } - else { - pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; - pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y; - } - } - if (blockDim.x <= 32) { - if (IS_AVG_POOLING) { - warpReduceSum(pooling); - } - else { - warpReduceMax(pooling); - } - } - else { - if (IS_AVG_POOLING) { - blockReduceSum(pooling); - } - else { - blockReduceMax(pooling); - } - } - __syncthreads(); - if (threadIdx.x == 0) { - T2 output_val; - if (IS_AVG_POOLING) { - output_val.x = T(pooling[0] / HW); - output_val.y = T(pooling[1] / HW); - } - else { - output_val.x = T(pooling[0]); - output_val.y = T(pooling[1]); - } - output[0] = output_val; - } -} - -template -void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord filter_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - cutlass::Tensor4DCoord padding, - cutlass::MatrixCoord stride, - TensorRef ref_input, - TensorRef ref_output, - int poolingType, //0 for avg pooling ; 1 for max pooling - cudaStream_t stream) { - - assert(input_tensor_size.n() == output_tensor_size.n() && - input_tensor_size.c() == output_tensor_size.c()); - - const int N = input_tensor_size.n(); - const int H = input_tensor_size.h(); - const int W = input_tensor_size.w(); - const int C = input_tensor_size.c(); - const int padding_H = padding.h(); - const int padding_W = padding.w(); - const int kernel_H = filter_tensor_size.h(); - const int kernel_W = filter_tensor_size.w(); - const int stride_H = stride.row(); - const int stride_W = stride.column(); - - const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); - const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); - - assert(output_tensor_size.h() == output_H && - output_tensor_size.w() == output_W); - - if (C % 2 != 0) { - if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { - dim3 grid(C, N); - dim3 block(256); - if (H*W < block.x){ - block.x = (H*W + 31)/32*32; - } - if (poolingType == 0) { - pooling_nxhTo1x1_element1_kernel<<>>( - ref_output.data(), - ref_input.data(), - N, - H*W, - C); - } // if (poolingType == 0) - else { - pooling_nxhTo1x1_element1_kernel<<>>( - ref_output.data(), - ref_input.data(), - N, - H*W, - C); - } - } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) - else { - dim3 grid(N, output_H, output_W); - dim3 block(256); - if (C < block.x) { - block.x = C; - } - if (poolingType == 0) { - pooling_nhwc_element1_kernel<<>>( - ref_output.data(), - ref_input.data(), - N, - H, - W, - C, - output_H, - output_W, - kernel_H, - kernel_W, - stride_H, - stride_W, - padding_H, - padding_W); - } // if (poolingType == 0) - else { - pooling_nhwc_element1_kernel<<>>( - ref_output.data(), - ref_input.data(), - N, - H, - W, - C, - output_H, - output_W, - kernel_H, - kernel_W, - stride_H, - stride_W, - padding_H, - padding_W); - } - } - } // if (C % 2 != 0)) - else { - if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { - dim3 grid(C/2, N); - dim3 block(256); - if (H*W < block.x){ - block.x = (H*W + 31)/32*32; - } - if (poolingType == 0) { - if (std::is_same::value) { - pooling_nxhTo1x1_element2_kernel<<>>( - (float2*)(ref_output.data()), - (const float2*)(ref_input.data()), - N, - H*W, - C); - } // if (std::is_same::value) - else { - pooling_nxhTo1x1_element2_kernel<<>>( - (half2*)(ref_output.data()), - (const half2*)(ref_input.data()), - N, - H*W, - C); - } - } // if (poolingType == 0) - else { - if (std::is_same::value) { - pooling_nxhTo1x1_element2_kernel<<>>( - (float2*)(ref_output.data()), - (const float2*)(ref_input.data()), - N, - H*W, - C); - } // if (std::is_same::value) - else { - pooling_nxhTo1x1_element2_kernel<<>>( - (half2*)(ref_output.data()), - (const half2*)(ref_input.data()), - N, - H*W, - C); - } - } - } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) - else { - dim3 grid(N, output_H, output_W); - dim3 block(256); - if (C/2 < block.x) { - block.x = C/2; - } - if (poolingType == 0) { - if (std::is_same::value) { - pooling_nhwc_element2_kernel<<>>( - (float2*)(ref_output.data()), - (const float2*)(ref_input.data()), - N, - H, - W, - C/2, - output_H, - output_W, - kernel_H, - kernel_W, - stride_H, - stride_W, - padding_H, - padding_W); - } // if (std::is_same::value) - else { - pooling_nhwc_element2_kernel<<>>( - (half2*)(ref_output.data()), - (const half2*)(ref_input.data()), - N, - H, - W, - C/2, - output_H, - output_W, - kernel_H, - kernel_W, - stride_H, - stride_W, - padding_H, - padding_W); - } - } // if (poolingType == 0) - else { - if (std::is_same::value) { - pooling_nhwc_element2_kernel<<>>( - (float2*)(ref_output.data()), - (const float2*)(ref_input.data()), - N, - H, - W, - C/2, - output_H, - output_W, - kernel_H, - kernel_W, - stride_H, - stride_W, - padding_H, - padding_W); - } // if (std::is_same::value) - else { - pooling_nhwc_element2_kernel<<>>( - (half2*)(ref_output.data()), - (const half2*)(ref_input.data()), - N, - H, - W, - C/2, - output_H, - output_W, - kernel_H, - kernel_W, - stride_H, - stride_W, - padding_H, - padding_W); - } - } - } - } -} - -} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h deleted file mode 100644 index babfecd39205ebff39794133868e4a95b7e9525c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h +++ /dev/null @@ -1,144 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. - */ - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" - -namespace cutlass { - -/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. - * \tparam T: data type - */ -template -void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - TensorRef ref_input, - TensorRef ref_output, - cudaStream_t stream); - - -template -__global__ void nhwc_to_nchw_kernel(T *output, - const T *input, - const int n, - const int h, - const int w, - const int c) { - - const int hw = h*w; - const int hwc = hw*c; - __shared__ T shbuf[32 * (32 + 1)]; - const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; - const int32_t wid = tid / 32; - const int32_t lid = tid % 32; - const int32_t ni = blockIdx.z; - const int32_t hwi0 = blockIdx.y * 32; - const int32_t ci0 = blockIdx.x * 32; - - const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; - const T *A = input + input_idx; - if (ci0 + lid < c) { - const int lid_x_33 = lid * 33; - if ((hwi0 + 32) <= hw) { - int hwi = wid; // between 0 and 7 - CUTLASS_PRAGMA_UNROLL - for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { - shbuf[lid_x_33 + hwi] = A[lid]; - A = &A[8 * c]; - hwi += 8; - } - } else { - for (int hwi = wid; hwi < 32; hwi += 8) { - if ((hwi + hwi0) < hw) { - shbuf[lid_x_33 + hwi] = A[lid]; - } - A = &A[8 * c]; - } - } - } - __syncthreads(); - - const int32_t hwiOut = hwi0 + lid; - output = &output[ni * hwc + hwiOut]; - if (hwiOut < hw) { - if (ci0 + 32 < c) { - int cI = wid; - CUTLASS_PRAGMA_UNROLL - for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { - output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; - cI += 8; - } - } else { - for (int cI = wid; cI < 32; cI += 8) { - if (ci0 + cI < c) { - output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; - } - } - } - } -} - -template -void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, - cutlass::Tensor4DCoord output_tensor_size, - TensorRef ref_input, - TensorRef ref_output, - cudaStream_t stream) { - - assert( - input_tensor_size.n() == output_tensor_size.n() && - input_tensor_size.h() == output_tensor_size.c() && - input_tensor_size.w() == output_tensor_size.h() && - input_tensor_size.c() == output_tensor_size.w()); - - int n = input_tensor_size.n(); - int h = input_tensor_size.h(); - int w = input_tensor_size.w(); - int c = input_tensor_size.c(); - - dim3 grid((c + 31)/32, (h*w + 31)/32, n); - dim3 block(32, 8); - nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), - n, h, w, c); - -} - -} //namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h deleted file mode 100644 index 0d1b1af56e4463640edc3e9c82533baf815c9b27..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h +++ /dev/null @@ -1,186 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/util/device_utils.h" -#include - -namespace cutlass { - -__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input, - const float4 *weight, - const int m, const int n, float epsilon) { - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean; - float local_sums[1] = {0.0f}; - const int n_8 = n / 8; - int offset = m_idx * n_8; - input += offset; - output += offset; - - for (int index = tid; index < n_8; index += bdimx) { - const float4 local_val = input[index]; - const half2 *h1 = (half2 *)&local_val.x; - const half2 *h2 = (half2 *)&local_val.y; - const half2 *h3 = (half2 *)&local_val.z; - const half2 *h4 = (half2 *)&local_val.w; - local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + - static_cast(h1->y) * static_cast(h1->y) + - static_cast(h2->x) * static_cast(h2->x) + - static_cast(h2->y) * static_cast(h2->y) + - static_cast(h3->x) * static_cast(h3->x) + - static_cast(h3->y) * static_cast(h3->y) + - static_cast(h4->x) * static_cast(h4->x) + - static_cast(h4->y) * static_cast(h4->y); - } - - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = rsqrtf(local_sums[0] / n + epsilon); - } - __syncthreads(); - - for (int index = tid; index < n_8; index += bdimx) { - const float4 local_val = input[index]; - const float4 weight_val = weight[index]; - - const half2 *l1 = (half2 *)&local_val.x; - const half2 *l2 = (half2 *)&local_val.y; - const half2 *l3 = (half2 *)&local_val.z; - const half2 *l4 = (half2 *)&local_val.w; - - const half2 *g1 = (half2 *)&weight_val.x; - const half2 *g2 = (half2 *)&weight_val.y; - const half2 *g3 = (half2 *)&weight_val.z; - const half2 *g4 = (half2 *)&weight_val.w; - - float4 tmp; - half2 *h1 = (half2 *)&tmp.x; - half2 *h2 = (half2 *)&tmp.y; - half2 *h3 = (half2 *)&tmp.z; - half2 *h4 = (half2 *)&tmp.w; - - h1->x = half(static_cast(l1->x) * s_mean * static_cast(g1->x)); - h1->y = half(static_cast(l1->y) * s_mean * static_cast(g1->y)); - h2->x = half(static_cast(l2->x) * s_mean * static_cast(g2->x)); - h2->y = half(static_cast(l2->y) * s_mean * static_cast(g2->y)); - h3->x = half(static_cast(l3->x) * s_mean * static_cast(g3->x)); - h3->y = half(static_cast(l3->y) * s_mean * static_cast(g3->y)); - h4->x = half(static_cast(l4->x) * s_mean * static_cast(g4->x)); - h4->y = half(static_cast(l4->y) * s_mean * static_cast(g4->y)); - - output[index] = tmp; - } -} - -template -__global__ void rmsnorm_twoPassAlgo_e1(T* output, - const T* input, - const T* weight, - const int m, const int n, - float epsilon) -{ - const int m_idx = blockIdx.x; - const int tid = threadIdx.x; - const int bdimx = blockDim.x; - __shared__ float s_mean; - float local_sums[1] = {0.0f}; - int offset = m_idx * n; - input += offset; - output += offset; - - for (int index = tid ; index < n ; index += bdimx){ - float local_val = static_cast(input[index]); - local_sums[0] += local_val * local_val; - } - if (blockDim.x <= 32) { - warpReduceSum(local_sums); - } - else { - blockReduceSum(local_sums); - } - if (threadIdx.x == 0) { - s_mean = rsqrtf(local_sums[0] / n + epsilon); - } - __syncthreads(); - - for (int index = tid ; index < n ; index += bdimx){ - const T weight_val = weight[index]; - const T local_val = input[index]; - output[index] = T(static_cast(local_val) * s_mean * static_cast(weight_val)); - } -} - -template -void rmsnorm(cutlass::MatrixCoord tensor_size, - TensorRef ref_output, - TensorRef ref_input, - TensorRef ref_weight, - cudaStream_t stream, float epsilon = 1e-5f){ - const int m = tensor_size.row(); - const int n = tensor_size.column(); - T* output = ref_output.data(); - const T* input = ref_input.data(); - const T* weight = ref_weight.data(); - dim3 grid(m); - - if (n % 8 == 0 && std::is_same::value) { - dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); - - rmsnorm_twoPassAlgo_e8<<>>( - (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); - } else { - dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); - - rmsnorm_twoPassAlgo_e1<<>>( - output, input, weight, m, n, epsilon); - } - - auto result = cudaGetLastError(); - if (result != cudaSuccess) { - std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl; - abort(); - } -} - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h deleted file mode 100644 index 9747d50975d7d35df287f6b056aedc489adb317c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h +++ /dev/null @@ -1,127 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief utils code for device cutlass code -*/ - -#pragma once - -#include -#include -#define FINAL_MASK 0xffffffff - -struct half4 { - half x, y, z, w; -}; - -template -__inline__ __device__ T warpReduceSum(T* val) -{ -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceSum(T* val) -{ - __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduceSum(val); - - if (lane == 0) { -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[i][wid] = val[i]; - } - } - - __syncthreads(); - - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[i][lane] : (T)(0.0f); - } - warpReduceSum(val); - return (T)0.0f; -} - -template -__inline__ __device__ T warpReduceMax(T* val) -{ -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceMax(T* val) -{ - static __shared__ T shared[32][NUM]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - warpReduceMax(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - { -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[wid][i] = val[i]; - } - } - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX); - } - warpReduceMax(val); - - return (T)0.0f; -} - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h deleted file mode 100644 index 6565aba9607ad68defacb6e98d9f9bbc944cd48d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h +++ /dev/null @@ -1,157 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -/*! \file - \brief This header contains a class to parametrize a statistical distribution function. -*/ - -#include - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Distribution type -struct Distribution { - /// Variant types - enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; - - /// Distribution state - union { - /// Uniform distribution - struct { - double min; - double max; - // Percent elements set to NaN - double pnan; - } uniform; - - /// Gaussian distribution - struct { - double mean; - double stddev; - double pnz; - double pnzA; - double pnzB; - double pnzC; - } gaussian; - - /// Elements are linear combination of row and column index - struct { - double start; - double delta; - } sequential; - }; - - /// Active variant kind - Kind kind; - - /// Random values are cast to integer after scaling by this power of two - int int_scale; - - // - // Methods - // - - Distribution() : kind(Invalid), int_scale(0) {} - -/// Configures distribution as uniform random - Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) { - kind = Uniform; - uniform.min = _min; - uniform.max = _max; - int_scale = _int_scale; - uniform.pnan = _pnan; - return *this; - } - - /// Configures distribution as Gaussian distribution - Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) { - kind = Gaussian; - gaussian.mean = _mean; - gaussian.stddev = _stddev; - gaussian.pnz = _pnz; - gaussian.pnzA = _pnz; - gaussian.pnzB = _pnz; - gaussian.pnzC = _pnz; - int_scale = _int_scale; - return *this; - } - - /// Sets identity - Distribution &set_identity() { - kind = Identity; - return *this; - } - - /// Sets sequential - Distribution &set_sequential(double start, double delta, int _int_scale = 0) { - kind = Sequential; - sequential.start = start; - sequential.delta = delta; - int_scale = _int_scale; - return *this; - } -}; - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Prints a Distribution to ostream -inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { - switch (dist.kind) { - case cutlass::Distribution::Uniform: - out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max - << ", pnan: " << dist.uniform.pnan; - break; - case cutlass::Distribution::Gaussian: - out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev - << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " - << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; - break; - case cutlass::Distribution::Identity: - out << "identity"; - break; - case cutlass::Distribution::Sequential: - out << "sequential"; - break; - default: - out << "unknown"; - } - - out << ", int_scale: " << dist.int_scale; - - return out; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h deleted file mode 100644 index f2b7df6cb1c465a312d76566768cb79fcdfffee4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h +++ /dev/null @@ -1,69 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -/** - * \file - * \brief C++ exception semantics for CUDA error codes - */ - -#include -#include -#include - -#include "cutlass/platform/platform.h" - -namespace cutlass { - -/// C++ exception wrapper for CUDA \p cudaError_t -class cuda_exception : public std::exception { - public: - /// Constructor - cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} - - /// Returns the underlying CUDA \p cudaError_t - cudaError_t cudaError() const { return err; } - - protected: - /// Explanatory string - const char* msg; - - /// Underlying CUDA \p cudaError_t - cudaError_t err; -}; - -/// Writes a cuda_exception instance to an output stream -inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { - return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); -} - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp deleted file mode 100644 index be2264466e350c062900a50e27e923847186d084..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp +++ /dev/null @@ -1,369 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief GETT command line parser to gather semantic modes, their stride order, and extents. -*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "cutlass/util/command_line.h" - -namespace cutlass { - -// Output shortcuts -std::ostream& operator<<(std::ostream& os, std::vector data) { - for (auto& a : data) os << a; - return os; -} - -template -std::ostream& operator<<(std::ostream& os, std::vector data) { - for (auto& a : data) os << a << " "; - return os; -} - -struct GettCommandLine { - struct GettProblem { - using extent_type = int; - using stride_type = int64_t; - - // Row modes: appear in A and C/D - std::vector M; - std::vector ldAm; - std::vector ldCm; - - // Column modes: appear in B and C/D - std::vector N; - std::vector ldBn; - std::vector ldCn; - - // Reduction modes: appear in A and B - std::vector K; - std::vector ldAk; - std::vector ldBk; - - // Batch modes: appear in all in/out tensors - std::vector L; - std::vector ldAl; - std::vector ldBl; - std::vector ldCl; - }; - - static GettProblem - parse(int argc, char const* argv[], bool parse_verbose = false) { - using extent_type = typename GettProblem::extent_type; - using stride_type = typename GettProblem::stride_type; - - cutlass::CommandLine cmd(argc, argv); - - // modeA - std::vector a_mode; - cmd.get_cmd_line_arguments("modeA", a_mode); - - // modeB - std::vector b_mode; - cmd.get_cmd_line_arguments("modeB", b_mode); - - // modeC - std::vector c_mode; - cmd.get_cmd_line_arguments("modeC", c_mode); - - - // mode_sizes - std::map mode_size; - // First, initialize all modes in a, b, c to make sure they're in map - for (char a : a_mode) mode_size[a] = 1; - for (char b : b_mode) mode_size[b] = 1; - for (char c : c_mode) mode_size[c] = 1; - - // Then, overwrite the ones in -extent - std::vector > extent_tokens; - cmd.get_cmd_line_argument_pairs("extents", extent_tokens); - for (auto e : extent_tokens) { - if (std::get<0>(e).size() > 1) { - std::cerr << "ERROR: Mode name must only be 1 character long.\n"; - print_usage(); - exit(1); - } - char label = std::get<0>(e)[0]; - int size = std::stoi(std::get<1>(e)); - mode_size[label] = size; - } - - // Print out symbolic modes and their extents - if (parse_verbose) { - std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n"; - for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n"; - } - - // - // Collect/Compute strides - // - - std::map mode_ldA; - std::map mode_ldB; - std::map mode_ldC; - - { - stride_type current; - - current = 1; - for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; } - - current = 1; - for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; } - - current = 1; - for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; } - } - - // - // Collect mode categories - // - - std::vector row_mode; // rows - std::vector col_mode; // columns - std::vector red_mode; // reductions - std::vector bat_mode; // batches - - { - std::vector a_label = a_mode; - std::vector b_label = b_mode; - std::vector c_label = c_mode; - - std::sort(std::begin(a_label), std::end(a_label)); - std::sort(std::begin(b_label), std::end(b_label)); - std::sort(std::begin(c_label), std::end(c_label)); - - // std::set_intersections to find semantic category of each symbolic mode - std::set_intersection(std::begin(a_label), std::end(a_label), - std::begin(c_label), std::end(c_label), - std::back_inserter(row_mode)); - - std::set_intersection(std::begin(b_label), std::end(b_label), - std::begin(c_label), std::end(c_label), - std::back_inserter(col_mode)); - - std::set_intersection(std::begin(a_label), std::end(a_label), - std::begin(b_label), std::end(b_label), - std::back_inserter(red_mode)); - - std::set_intersection(std::begin(row_mode), std::end(row_mode), - std::begin(col_mode), std::end(col_mode), - std::back_inserter(bat_mode)); - - // std::set_difference to remove batch modes from other semantic modes - for (char l : bat_mode) { - row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode)); - col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode)); - red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode)); - } - } - - // Print out the semantic association of each symbolic mode - if (parse_verbose) { - std::cout << " rows : " << row_mode << '\n'; - std::cout << " cols : " << col_mode << '\n'; - std::cout << " reds : " << red_mode << '\n'; - std::cout << " bats : " << bat_mode << '\n'; - } - - // - // Permute modes - // - - // Permute the batched modes to promote coalescing - // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size - std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { - return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) - < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); - }); - // Compute sizes and strides of ordered reduction modes - std::vector L; - std::vector ldAl; - std::vector ldBl; - std::vector ldCl; - for (char l : bat_mode) { - L.push_back(mode_size[l]); - ldAl.push_back(mode_ldA[l]); - ldBl.push_back(mode_ldB[l]); - ldCl.push_back(mode_ldC[l]); - } - - // Permute the reduction modes to promote coalescing - // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size - std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { - return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) - < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); - }); - // Compute sizes and strides of ordered reduction modes - std::vector K; - std::vector ldAk; - std::vector ldBk; - for (char k : red_mode) { - K.push_back(mode_size[k]); - ldAk.push_back(mode_ldA[k]); - ldBk.push_back(mode_ldB[k]); - } - - // Permute the row modes to promote coalescing - // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm - std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) { - return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) - < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); - }); - // Compute sizes and strides of ordered row modes - std::vector M; - std::vector ldAm; - std::vector ldCm; - for (char m : row_mode) { - M.push_back(mode_size[m]); - ldAm.push_back(mode_ldA[m]); - ldCm.push_back(mode_ldC[m]); - } - - // Permute the col modes to promote coalescing - // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn - std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { - return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) - < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); - }); - // Compute sizes and strides of ordered col modes - std::vector N; - std::vector ldBn; - std::vector ldCn; - for (char n : col_mode) { - N.push_back(mode_size[n]); - ldBn.push_back(mode_ldB[n]); - ldCn.push_back(mode_ldC[n]); - } - - if (parse_verbose) { - std::cout << "C_"; - if (! row_mode.empty()) { - std::cout << "(" << row_mode << ")"; - } - if (! col_mode.empty()) { - std::cout << "(" << col_mode << ")"; - } - if (! bat_mode.empty()) { - std::cout << "(" << bat_mode << ")"; - } - std::cout << " = A_"; - if (! row_mode.empty()) { - std::cout << "(" << row_mode << ")"; - } - if (! red_mode.empty()) { - std::cout << "(" << red_mode << ")"; - } - if (! bat_mode.empty()) { - std::cout << "(" << bat_mode << ")"; - } - std::cout << " * B_"; - if (! col_mode.empty()) { - std::cout << "(" << col_mode << ")"; - } - if (! red_mode.empty()) { - std::cout << "(" << red_mode << ")"; - } - if (! bat_mode.empty()) { - std::cout << "(" << bat_mode << ")"; - } - std::cout << '\n'; - - int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{}); - int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{}); - int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{}); - int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{}); - - std::cout << " M : (" << M_size << ") "; - for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " "; - std::cout << '\n'; - std::cout << " N : (" << N_size << ") "; - for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " "; - std::cout << '\n'; - std::cout << " K : (" << K_size << ") "; - for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " "; - std::cout << '\n'; - std::cout << " L : (" << L_size << ") "; - for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " "; - std::cout << '\n'; - - std::cout << " ldAm : " << ldAm << '\n'; - std::cout << " ldAk : " << ldAk << '\n'; - std::cout << " ldAl : " << ldAl << '\n'; - std::cout << " ldBn : " << ldBn << '\n'; - std::cout << " ldBk : " << ldBk << '\n'; - std::cout << " ldBl : " << ldBl << '\n'; - std::cout << " ldCm : " << ldCm << '\n'; - std::cout << " ldCn : " << ldCn << '\n'; - std::cout << " ldCl : " << ldCl << '\n'; - } - - return {M, ldAm, ldCm, - N, ldBn, ldCn, - K, ldAk, ldBk, - L, ldAl, ldBl, ldCl}; - } - - static void - print_usage() { - std::cout << - "GETT problem command line parser:\n" - " --modeA=\n" - " A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n" - " The semantic association of each symbolic mode is determined automatically.\n\n" - - " --modeB=\n" - " A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n" - " The semantic association of each symbolic mode is determined automatically.\n\n" - - " --modeC=\n" - " A comma delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n" - " The semantic association of each symbolic mode is determined automatically.\n\n" - - " --extents=\n" - " A command delimited list of symbolic mode and its corresponding extent.\n" - " Extents are defaulted to 1 if any are not provided.\n\n" - - "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n"; - } -}; - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp deleted file mode 100644 index 58d08b860c9e665d170fd022ed0d95875e029019..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp +++ /dev/null @@ -1,116 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include - -#include - -namespace cute -{ - -void -device_init(int device_id, bool quiet = false) -{ - cudaDeviceProp device_prop; - std::size_t device_free_physmem; - std::size_t device_total_physmem; - - CUTE_CHECK_ERROR(cudaSetDevice(device_id)); - CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); - CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); - - if (device_prop.major < 1) { - fprintf(stderr, "Device does not support CUDA.\n"); - exit(1); - } - - //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; - - if (!quiet) { - printf("Using device %d: %s (SM%d, %d SMs)\n", - device_id, device_prop.name, - device_prop.major * 10 + device_prop.minor, - device_prop.multiProcessorCount); - fflush(stdout); - } -} - -/** - * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. - */ -inline int -_ConvertSMVer2Cores(int major, int minor) -{ - // Defines for GPU Architecture types (using the SM version to determine - // the # of cores per SM - typedef struct { - int SM; // 0xMm (hexadecimal notation), M = SM Major version, - // and m = SM minor version - int Cores; - } sSMtoCores; - - sSMtoCores nGpuArchCoresPerSM[] = { - {0x30, 192}, - {0x32, 192}, - {0x35, 192}, - {0x37, 192}, - {0x50, 128}, - {0x52, 128}, - {0x53, 128}, - {0x60, 64}, - {0x61, 128}, - {0x62, 128}, - {0x70, 64}, - {0x72, 64}, - {0x75, 64}, - {-1, -1}}; - - int index = 0; - - while (nGpuArchCoresPerSM[index].SM != -1) { - if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { - return nGpuArchCoresPerSM[index].Cores; - } - index++; - } - - // If we don't find the values, we default use the previous one - // to run properly - printf("MapSMtoCores for SM %d.%d is undefined." - " Default to use %d Cores/SM\n", - major, minor, nGpuArchCoresPerSM[index - 1].Cores); - - return nGpuArchCoresPerSM[index - 1].Cores; -} - -} // end namespace cute diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h deleted file mode 100644 index 4e7718059dfaea0c77d7ebf67789f307b4ca0cf6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h +++ /dev/null @@ -1,111 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief reorder data from the host side -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/tensor_view.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/gemm.h" - -namespace cutlass { - -/// This is needed for the interleaved integer tensor core kernels. The purpose -/// is to use skip the shared memory part in the epilogue. -template -void reorder_column(TensorRef dest, - TensorRef src, - cutlass::gemm::GemmCoord problem_size) { - const int InstructionShapeCol = 8; - // 4 threads per Quad - const int ElementsPerThread = InstructionShapeCol / 4; - // 4 threads per Quad - const int ReorderedElementsPerThread = - Interleaved / 4; - - for (int n = 0; n < problem_size.n(); n++) { - for (int k = 0; k < problem_size.k(); k++) { - dest.at({k, (n / Interleaved) * Interleaved + - ((n % ReorderedElementsPerThread) / ElementsPerThread) * - InstructionShapeCol + - ((n % Interleaved) / ReorderedElementsPerThread) * - ElementsPerThread + - (n % ElementsPerThread)}) = src.at({k, n}); - } - } -} - -template -void reorder_convK(TensorRef dest, - TensorRef src, - cutlass::gemm::GemmCoord problem_size) { - - TensorRef> mappedDest(dest.data(), dest.stride(0)); - TensorRef> mappedSrc(src.data(), src.stride(0)); - - reorder_column( - mappedDest, mappedSrc, problem_size); -} - -/// This is needed for the sparse tensor core kernels. The purpose -/// is to use ldmatrix to load from shared memory to the register file. -template -void reorder_meta(TensorRef dest, - TensorRef src, - cutlass::gemm::GemmCoord problem_size) { - for (int m = 0; m < problem_size.m(); m++) { - for (int k = 0; k < problem_size.k(); k++) { - // First reorder the rows. - int group = (sizeof(Element) == 2) ? 32 : 16; - int interweave = (sizeof(Element) == 2) ? 4 : 2; - - int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; - int dest_col = k; - - // Next swizzle the 2x2 blocks from Z to N. - if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { - ++dest_row; - --dest_col; - } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { - --dest_row; - ++dest_col; - } - - dest.at({dest_row, dest_col}) = src.at({m, k}); - } - } -} -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h deleted file mode 100644 index 3226055ad0836e7a3059340ff16d54594987e0c8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h +++ /dev/null @@ -1,541 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -/*! \file - \brief HostTensor contributes management for both host and device memory. - - HostTensor allocates host and device memory upon construction. Basic element-wise operations on - host memory synchronize device memory automatically. Explicit copy operations provide abstractions - for CUDA memcpy operations. - - Call {host, device}_{data, ref, view}() for accessing host or device memory. - - See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/fast_math.h" - -#include "device_memory.h" - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Host tensor -template < - /// Data type of element stored within tensor (concept: NumericType) - typename Element_, - /// Defines a mapping from logical coordinate to linear memory (concept: Layout) - typename Layout_ -> -class HostTensor { -public: - - /// Data type of individual access - using Element = Element_; - - /// Mapping function from logical coordinate to linear memory - using Layout = Layout_; - - /// Logical rank of tensor index space - static int const kRank = Layout::kRank; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Layout's stride vector - using Stride = typename Layout::Stride; - - /// Tensor reference to device memory - using TensorRef = TensorRef; - - /// Tensor reference to constant device memory - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - /// Tensor reference to device memory - using TensorView = TensorView; - - /// Tensor reference to constant device memory - using ConstTensorView = typename TensorView::ConstTensorView; - - /// Reference to element in tensor - using Reference = typename TensorRef::Reference; - - /// Constant reference to element in tensor - using ConstReference = typename ConstTensorRef::Reference; - -private: - using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization - typename platform::conditional_t::value % 8 == 0, // Handle subbyte types - Element, uint8_t>>; - using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; - static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; - static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; - static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; - static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit; - - // - // Data members - // - - /// Extent of tensor in logical dimensions - TensorCoord extent_; - - /// Layout object - Layout layout_; - - /// Host-side memory allocation - std::vector host_; - - /// Device-side memory - device_memory::allocation device_; - - /// number of containers - size_t count_to_container_storage_unit_count(size_t count) { - return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; - } - -public: - // - // Device and Host Methods - // - - /// Default constructor - HostTensor() {} - - /// Constructs a tensor given an extent. Assumes a packed layout - HostTensor( - TensorCoord const &extent, - bool device_backed = true - ) { - - this->reset(extent, Layout::packed(extent), device_backed); - } - - /// Constructs a tensor given an extent and layout - HostTensor( - TensorCoord const &extent, - Layout const &layout, - bool device_backed = true - ) { - - this->reset(extent, layout, device_backed); - } - - ~HostTensor() { } - - /// Clears the HostTensor allocation to size/capacity = 0 - void reset() { - extent_ = TensorCoord(); - layout_ = Layout::packed(extent_); - - host_.clear(); - device_.reset(); - } - - /// Resizes internal memory allocations without affecting layout or extent - void reserve( - size_t count, ///< size of tensor in elements - bool device_backed_ = true) { ///< if true, device memory is also allocated -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); -#endif - - device_.reset(); - host_.clear(); - - size_t count_container = count_to_container_storage_unit_count(count); -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); -#endif - host_.resize(count_container); - - // Allocate memory - StorageUnit* device_memory = nullptr; - if (device_backed_) { -#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); -#endif - device_memory = device_memory::allocate(count_container); - } - device_.reset(device_memory, device_backed_ ? count_container : 0); - } - - /// Updates the extent and layout of the HostTensor. Allocates memory according to the new - /// extent and layout. - void reset( - TensorCoord const &extent, ///< extent of logical tensor - Layout const &layout, ///< layout object of tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - extent_ = extent; - layout_ = layout; - - reserve(size_t(layout_.capacity(extent_)), device_backed_); - } - - /// Updates the extent and layout of the HostTensor. Allocates memory according to the new - /// extent and layout. Assumes a packed tensor configuration. - void reset( - TensorCoord const &extent, ///< extent of logical tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - reset(extent, Layout::packed(extent), device_backed_); - } - - /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. - /// To force allocation, call reset(). - void resize( - TensorCoord const &extent, ///< extent of logical tensor - Layout const &layout, ///< layout object of tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - extent_ = extent; - layout_ = layout; - - LongIndex new_size = size_t(layout_.capacity(extent_)); - LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); - - if (static_cast(new_size_container) > host_.size()) { - reserve(new_size, device_backed_); - } - } - - /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. - /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. - void resize( - TensorCoord const &extent, ///< extent of logical tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - resize(extent, Layout::packed(extent), device_backed_); - } - - /// Returns the logical number of elements stored in the host tensor - size_t size() const { - return layout_.capacity(extent_); - } - - /// Returns the logical capacity in terms of number of elements. May be larger than the size(). - LongIndex capacity() const { - return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; - } - - /// Gets pointer to host data - Element * host_data() { return reinterpret_cast(host_.data()); } - - /// Gets pointer to host data with a pointer offset - Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } - - /// Gets a reference to an element in host memory - Reference host_data(LongIndex idx) { - return ReferenceFactory::get(host_data(), idx); - } - - /// Gets pointer to host data - Element const * host_data() const { return reinterpret_cast(host_.data()); } - - /// Gets pointer to host data with a pointer offset - Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } - - /// Gets a constant reference to an element in host memory - ConstReference host_data(LongIndex idx) const { - return ReferenceFactory::get(host_data(), idx); - } - - /// Gets pointer to device data - Element * device_data() { return reinterpret_cast(device_.get()); } - - /// Gets pointer to device data - Element const * device_data() const { return reinterpret_cast(device_.get()); } - - /// Gets pointer to device data with a pointer offset - Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } - - /// Gets pointer to device data with a pointer offset - Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } - - /// Accesses the tensor reference pointing to data - TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } - - /// Accesses the tensor reference pointing to data - ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } - - /// Accesses the tensor reference pointing to data - TensorRef device_ref(LongIndex ptr_element_offset=0) { - return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); - } - - /// Accesses the tensor reference pointing to data - ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { - return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); - } - - /// Accesses the tensor reference pointing to data - TensorView host_view(LongIndex ptr_element_offset=0) { - return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); - } - - /// Accesses the tensor reference pointing to data - ConstTensorView host_view(LongIndex ptr_element_offset=0) const { - return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); - } - - /// Accesses the tensor reference pointing to data - TensorView device_view(LongIndex ptr_element_offset=0) { - return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); - } - - /// Accesses the tensor reference pointing to data - ConstTensorView device_view(LongIndex ptr_element_offset=0) const { - return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); - } - - /// Returns true if device memory is allocated - bool device_backed() const { - return (device_.get() == nullptr) ? false : true; - } - - - /// Returns the layout object - Layout & layout() { - return layout_; - } - - /// Returns the layout object - Layout layout() const { - return layout_; - } - - /// Returns the layout object's stride vector - Stride stride() const { - return layout_.stride(); - } - - /// Returns the layout object's stride vector - Stride & stride() { - return layout_.stride(); - } - - /// Returns the layout object's stride in a given physical dimension - LongIndex stride(int dim) const { - return layout_.stride().at(dim); - } - - /// Returns the layout object's stride in a given physical dimension - LongIndex & stride(int dim) { - return layout_.stride().at(dim); - } - - /// Computes the offset of an index from the origin of the tensor - LongIndex offset(TensorCoord const& coord) const { - return layout_(coord); - } - - /// Returns a reference to the element at the logical Coord in host memory - Reference at(TensorCoord const& coord) { - return host_data(offset(coord)); - } - - /// Returns a const reference to the element at the logical Coord in host memory - ConstReference at(TensorCoord const& coord) const { - return host_data(offset(coord)); - } - - /// Returns the extent of the tensor - TensorCoord extent() const { - return extent_; - } - - /// Returns the extent of the tensor - TensorCoord & extent() { - return extent_; - } - - /// Copies data from device to host - void sync_host() { - if (device_backed()) { - device_memory::copy_to_host( - host_.data(), device_.get(), device_.size()); - } - } - - /// Copies data from host to device - void sync_device() { - if (device_backed()) { - device_memory::copy_to_device( - device_.get(), host_.data(), host_.size()); - } - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_device_to_host( - Element const* ptr_device, ///< source device memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_to_host( - host_.data(), reinterpret_cast(ptr_device), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_device_to_device( - Element const* ptr_device, ///< source device memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_device_to_device( - device_.get(), reinterpret_cast(ptr_device), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_host_to_device( - Element const* ptr_host, ///< source host memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_to_device( - device_.get(), reinterpret_cast(ptr_host), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_host_to_host( - Element const* ptr_host, ///< source host memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_host_to_host( - host_.data(), reinterpret_cast(ptr_host), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_device_to_host( - Element * ptr_host, ///< source device memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_to_host( - reinterpret_cast(ptr_host), device_.get(), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_device_to_device( - Element * ptr_device, ///< source device memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_device_to_device( - reinterpret_cast(ptr_device), device_.get(), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_host_to_device( - Element * ptr_device, ///< source host memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_to_device( - reinterpret_cast(ptr_device), host_.data(), container_count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_host_to_host( - Element * ptr_host, ///< source host memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - size_t container_count = count_to_container_storage_unit_count(count); - device_memory::copy_host_to_host( - reinterpret_cast(ptr_host), host_.data(), container_count); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h deleted file mode 100644 index ca770e4d76cfe2df16309baca0b2de8ab6de98c4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h +++ /dev/null @@ -1,591 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -/*! \file - \brief HostTensor contributes management for both host and device memory. - - HostTensor allocates host and device memory upon construction. Basic element-wise operations on - host memory synchronize device memory automatically. Explicit copy operations provide abstractions - for CUDA memcpy operations. - - Call {host, device}_{data, ref, view}() for accessing host or device memory. - - See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. -*/ - -#include - -#include "cutlass/cutlass.h" - -#include "cutlass/tensor_ref_planar_complex.h" -#include "cutlass/tensor_view_planar_complex.h" - -#include "device_memory.h" - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Host tensor -template < - /// Data type of element stored within tensor (concept: NumericType) - typename Element_, - /// Defines a mapping from logical coordinate to linear memory (concept: Layout) - typename Layout_ -> -class HostTensorPlanarComplex { -public: - - /// Data type of individual access - using Element = Element_; - - /// Mapping function from logical coordinate to linear memory - using Layout = Layout_; - - /// Logical rank of tensor index space - static int const kRank = Layout::kRank; - - /// Index type - using Index = typename Layout::Index; - - /// Long index used for pointer offsets - using LongIndex = typename Layout::LongIndex; - - /// Coordinate in logical tensor space - using TensorCoord = typename Layout::TensorCoord; - - /// Layout's stride vector - using Stride = typename Layout::Stride; - - /// Tensor reference to device memory - using TensorRef = TensorRefPlanarComplex; - - /// Tensor reference to constant device memory - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - /// Tensor reference to device memory - using TensorView = TensorViewPlanarComplex; - - /// Tensor reference to constant device memory - using ConstTensorView = typename TensorView::ConstTensorView; - - /// Reference to element in tensor - using Reference = typename TensorRef::Reference; - - /// Constant reference to element in tensor - using ConstReference = typename ConstTensorRef::Reference; - - private: - - // - // Data members - // - - /// Extent of tensor in logical dimensions - TensorCoord extent_; - - /// Layout object - Layout layout_; - - /// Host-side memory allocation - std::vector host_; - - /// Device-side memory - device_memory::allocation device_; - - public: - // - // Device and Host Methods - // - - /// Default constructor - HostTensorPlanarComplex() {} - - /// Constructs a tensor given an extent. Assumes a packed layout - HostTensorPlanarComplex( - TensorCoord const &extent, - bool device_backed = true - ) { - - this->reset(extent, Layout::packed(extent), device_backed); - } - - /// Constructs a tensor given an extent and layout - HostTensorPlanarComplex( - TensorCoord const &extent, - Layout const &layout, - bool device_backed = true - ) { - - this->reset(extent, layout, device_backed); - } - - ~HostTensorPlanarComplex() { } - - /// Clears the HostTensor allocation to size/capacity = 0 - void reset() { - extent_ = TensorCoord(); - layout_ = Layout::packed(extent_); - - host_.clear(); - device_.reset(); - } - - /// Resizes internal memory allocations without affecting layout or extent - void reserve( - size_t count, ///< size of tensor in elements - bool device_backed_ = true) { ///< if true, device memory is also allocated - - device_.reset(); - host_.clear(); - - host_.resize(count * 2); - - // Allocate memory - Element* device_memory = nullptr; - if (device_backed_) { - device_memory = device_memory::allocate(count * 2); - } - device_.reset(device_memory, device_backed_ ? count * 2 : 0); - } - - /// Updates the extent and layout of the HostTensor. Allocates memory according to the new - /// extent and layout. - void reset( - TensorCoord const &extent, ///< extent of logical tensor - Layout const &layout, ///< layout object of tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - extent_ = extent; - layout_ = layout; - - reserve(size_t(layout_.capacity(extent_)), device_backed_); - } - - /// Updates the extent and layout of the HostTensor. Allocates memory according to the new - /// extent and layout. Assumes a packed tensor configuration. - void reset( - TensorCoord const &extent, ///< extent of logical tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - reset(extent, Layout::packed(extent), device_backed_); - } - - /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. - /// To force allocation, call reset(). - void resize( - TensorCoord const &extent, ///< extent of logical tensor - Layout const &layout, ///< layout object of tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - extent_ = extent; - layout_ = layout; - - LongIndex new_size = size_t(layout_.capacity(extent_)); - - if (static_cast(new_size * 2) > host_.size()) { - reserve(new_size); - } - } - - /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. - /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. - void resize( - TensorCoord const &extent, ///< extent of logical tensor - bool device_backed_ = true) { ///< if true, device memory is also allocated. - - resize(extent, Layout::packed(extent), device_backed_); - } - - /// Returns the number of elements stored in the host tensor - size_t size() const { - return host_.size() / 2; - } - - /// Returns the logical capacity based on extent and layout. May differ from size(). - LongIndex capacity() const { - return layout_.capacity(extent_); - } - - /// Stride between real and imaginary parts - LongIndex imaginary_stride() const { - return host_.size() / 2; - } - - /// Gets pointer to host data - Element * host_data() { return host_.data(); } - - /// Gets pointer to host data imaginary part - Element * host_data_imag() { return host_.data() + imaginary_stride(); } - - /// Gets pointer to host data with a pointer offset - Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } - - /// Gets pointer to host data with a pointer offset - Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } - - /// Gets a reference to an element in host memory - Reference host_data(LongIndex idx) { - return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); - } - - /// Gets pointer to host data - Element const * host_data() const { return host_.data(); } - - /// Gets pointer to host data imaginary part - Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } - - /// Gets a constant reference to an element in host memory - ConstReference host_data(LongIndex idx) const { - return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); - } - - /// Gets pointer to device data - Element * device_data() { return device_.get(); } - - /// Gets pointer to device data with a pointer offset - Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } - - /// Gets pointer to device data - Element const * device_data() const { return device_.get(); } - - /// Gets pointer to device data with a pointer offset - Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } - - /// Gets a pointer to the device data imaginary part - Element * device_data_imag() { return device_.get() + imaginary_stride(); } - - /// Accesses the tensor reference pointing to data - TensorRef host_ref(LongIndex ptr_element_offset=0) { - return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); - } - - /// Returns a tensor reference to the real part of the tensor - cutlass::TensorRef host_ref_real() { - return cutlass::TensorRef(host_data(), layout_); - } - - /// Returns a tensor reference to the real part of the tensor - cutlass::TensorRef host_ref_imag() { - return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); - } - - /// Accesses the tensor reference pointing to data - ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { - return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); - } - - /// Accesses the tensor reference pointing to data - TensorRef device_ref(LongIndex ptr_element_offset=0) { - return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); - } - - /// Accesses the tensor reference pointing to data - ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { - return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); - } - - /// Returns a tensor reference to the real part of the tensor - cutlass::TensorRef device_ref_real() { - return cutlass::TensorRef(device_data(), layout_); - } - - /// Returns a tensor reference to the real part of the tensor - cutlass::TensorRef device_ref_imag() { - return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); - } - - /// Accesses the tensor reference pointing to data - TensorView host_view(LongIndex ptr_element_offset=0) { - return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); - } - - /// Accesses the tensor reference pointing to data - ConstTensorView host_view(LongIndex ptr_element_offset=0) const { - return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); - } - - /// Accesses the tensor reference pointing to data - cutlass::TensorView host_view_real() { - return cutlass::TensorView(host_data(), layout_, extent_); - } - - /// Accesses the tensor reference pointing to data - cutlass::TensorView host_view_imag() { - return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); - } - - /// Accesses the tensor reference pointing to data - TensorView device_view(LongIndex ptr_element_offset=0) { - return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); - } - - /// Accesses the tensor reference pointing to data - ConstTensorView device_view(LongIndex ptr_element_offset=0) const { - return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); - } - - /// Accesses the tensor reference pointing to data - cutlass::TensorView device_view_real() { - return cutlass::TensorView(device_data(), layout_, extent_); - } - - /// Accesses the tensor reference pointing to data - cutlass::TensorView device_view_imag() { - return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); - } - - /// Returns true if device memory is allocated - bool device_backed() const { - return (device_.get() == nullptr) ? false : true; - } - - /// Returns the layout object - Layout layout() const { - return layout_; - } - - /// Returns the layout object's stride vector - Stride stride() const { - return layout_.stride(); - } - - /// Returns the layout object's stride in a given physical dimension - Index stride(int dim) const { - return layout_.stride().at(dim); - } - - /// Computes the offset of an index from the origin of the tensor - LongIndex offset(TensorCoord const& coord) const { - return layout_(coord); - } - - /// Returns a reference to the element at the logical Coord in host memory - Reference at(TensorCoord const& coord) { - return host_data(offset(coord)); - } - - /// Returns a const reference to the element at the logical Coord in host memory - ConstReference at(TensorCoord const& coord) const { - return host_data(offset(coord)); - } - - /// Returns the extent of the tensor - TensorCoord extent() const { - return extent_; - } - - /// Returns the extent of the tensor - TensorCoord & extent() { - return extent_; - } - - /// Copies data from device to host - void sync_host() { - if (device_backed()) { - device_memory::copy_to_host( - host_data(), device_data(), imaginary_stride() * 2); - } - } - - /// Copies data from host to device - void sync_device() { - if (device_backed()) { - device_memory::copy_to_device( - device_data(), host_data(), imaginary_stride() * 2); - } - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_device_to_host( - Element const* ptr_device_real, ///< source device memory - Element const* ptr_device_imag, ///< source device memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_to_host( - host_data(), ptr_device_real, count); - - device_memory::copy_to_host( - host_data_imag(), ptr_device_imag, count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_device_to_device( - Element const* ptr_device_real, ///< source device memory - Element const* ptr_device_imag, ///< source device memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_device_to_device( - device_data(), ptr_device_real, count); - - device_memory::copy_device_to_device( - device_data_imag(), ptr_device_imag, count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_host_to_device( - Element const* ptr_host_real, ///< source host memory - Element const* ptr_host_imag, ///< source host memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_to_device( - device_data(), ptr_host_real, count); - - device_memory::copy_to_device( - device_data_imag(), ptr_host_imag, count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_in_host_to_host( - Element const* ptr_host_real, ///< source host memory - Element const* ptr_host_imag, ///< source host memory - LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_host_to_host( - host_data(), ptr_host_real, count); - - device_memory::copy_host_to_host( - host_data_imag(), ptr_host_imag, count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_device_to_host( - Element * ptr_host_real, ///< source device memory - Element * ptr_host_imag, ///< source device memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_to_host( - ptr_host_real, device_data(), count); - - device_memory::copy_to_host( - ptr_host_imag, device_data_imag(), count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_device_to_device( - Element * ptr_device_real, ///< source device memory - Element * ptr_device_imag, ///< source device memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_device_to_device( - ptr_device_real, device_data(), count); - - device_memory::copy_device_to_device( - ptr_device_imag, device_data_imag(), count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_host_to_device( - Element * ptr_device_real, ///< source device memory - Element * ptr_device_imag, ///< source device memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_to_device( - ptr_device_real, host_data(), count); - - device_memory::copy_to_device( - ptr_device_imag, host_data_imag(), count); - } - - /// Copy data from a caller-supplied device pointer into host memory. - void copy_out_host_to_host( - Element * ptr_host_real, ///< source host memory - Element * ptr_host_imag, ///< source host memory - LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. - - if (count < 0) { - count = capacity(); - } - else { - count = __NV_STD_MIN(capacity(), count); - } - - device_memory::copy_host_to_host( - ptr_host_real, host_data(), count); - - device_memory::copy_host_to_host( - ptr_host_imag, host_data_imag(), count); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h deleted file mode 100644 index 9cd62927432c65ce1f0187f46306f7e1198a1182..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h +++ /dev/null @@ -1,157 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief uncompress sparse matrix from the host side -*/ -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/tensor_view.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/gemm.h" - -namespace cutlass { - -// uncompress sparse tensor core A matrix -template -void uncompress(TensorRef uncompressed_tensor_a, - TensorRef tensor_a, - TensorRef tensor_e, int row, int col) { - // How many uncompressed data we can get with ElementE meta data - int DecompressedElementsPerElementE = - 256 / cutlass::sizeof_bits::value; - - // Process 4bit meta data a time - int step; - - // 1:2 or 2:4 or 4:8 - int a, b; - - if (cutlass::sizeof_bits::value == 4) { - step = 8; - a = 4; - b = 8; - } else if (cutlass::sizeof_bits::value == 8) { - step = 4; - a = 2; - b = 4; - } else if (cutlass::sizeof_bits::value == 16) { - step = 4; - a = 2; - b = 4; - } else if (cutlass::sizeof_bits::value == 32) { - step = 2; - a = 1; - b = 2; - } - - int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; - - for (int r = 0; r < row; ++r) { - for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { - - ElementE meta = tensor_e.at(MatrixCoord(r, c)); - - for (int i = 0; i < DecompressedElementsPerElementE; i += step) { - int e = (meta >> (i / step * 4)) & 0xf; - int idx0 = e & 0x3; - int idx1 = e >> 2; - - if (a == 1) idx0 = idx0 / 2; - - for (int ii = 0; ii < step; ii += ElementsPerE) { - int real_col = - c * DecompressedElementsPerElementE + i + ii; - int compressed_col = (real_col / b) * a; - - if (ii == (idx0 * ElementsPerE)) { - uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = - tensor_a.at(MatrixCoord(r, compressed_col)); - if (ElementsPerE == 2) - uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = - tensor_a.at(MatrixCoord(r, compressed_col + 1)); - } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { - uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = - tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); - if (ElementsPerE == 2) - uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = - tensor_a.at( - MatrixCoord(r, compressed_col + ElementsPerE + 1)); - } else { - uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = - ElementA(0); - if (ElementsPerE == 2) - uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = - ElementA(0); - } - } - } - } - } -} - -// uncompress ELL block sparse matrix -template -void uncompress_ell_block_sparse( - TensorRef uncompressed_tensor_a, - TensorRef tensor_a, - TensorRef ell_idx, - int rows, int cols, - int ell_num_cols, int ell_blocksize) { - - for (int r = 0; r < rows / ell_blocksize; ++r) { - for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { - - ElementE idx = ell_idx.at(MatrixCoord(r, c)); - - if (idx != -1) { - int row_begin = r * ell_blocksize; - int col_begin_real = idx * ell_blocksize; - int col_begin = c * ell_blocksize; - - for (int i = 0; i < ell_blocksize; ++i) { - for (int j = 0; j < ell_blocksize; ++j) { - uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = - tensor_a.at( - MatrixCoord(row_begin + i, col_begin +j)); - } - } - } - } - } -} - -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h deleted file mode 100644 index 6b72b043fc0c1271cf9f12e5cb9a81d29659cb0a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h +++ /dev/null @@ -1,38 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -// integer_sequence moved to cutlass/numeric_types.h - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp deleted file mode 100644 index 43f5a3f92d29f229703cc4c5f9071c11d0f89df4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ /dev/null @@ -1,472 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Utilities for mixed input data type kernels. -*/ - -#pragma once - -#include -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/arch/mma_sm90.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/device/tensor_fill.h" -#include "cute/util/type_traits.hpp" - -namespace cutlass { - -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - if (error != cudaSuccess) { \ - std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ - << " at line: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -template < - class QuantizedElement, - class DequantizedElement, - class OperandLayout, - class ElementScale, - class ElementZero, - class ScaleBroadCastLayout, - class ThrLayout> -__global__ void dequantize_kernel(DequantizedElement* dq_buffer, - QuantizedElement const* q_buffer, - OperandLayout const operand_layout, - ElementScale const* scale_buffer, - ElementZero const* zero_buffer, - ScaleBroadCastLayout const broadcasted_scale_layout, - ThrLayout thr_layout) { - using namespace cute; - - // Represent the full tensors to gmem elements. - // These are expected to have shape [MN, K, L] - cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); - cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr(q_buffer), operand_layout); - // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting - // It is expected that K % G == 0 - cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); - cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); - - // Assign 1 thread per element in the thread block - auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); // - auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) - - // Tile across the block - auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); - auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); - auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); - auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); - - auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); - auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); - auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); - auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); - - // Make a fragment of registers to hold gmem loads - cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); - cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); - cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); - cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); - cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); - cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); - - cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); - auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); - auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); - - const auto num_iters = cute::size<3>(tOpDq_gOpDq); - - for (int ii = 0; ii < num_iters; ++ii) { - const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); - if (thread_offset < cute::size<0>(operand_layout)) { - cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); - cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); - cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); - cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); - cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); - cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{}); - cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{}); - cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); - cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); - } - } -} - -template < - class QuantizedElement, - class DequantizedElement, - class OperandLayout, - class ElementScale, - class ElementZero, - class ScaleLayout> -static void dequantize(DequantizedElement* dq_buffer, - QuantizedElement const* q_buffer, - OperandLayout const operand_layout, - ElementScale const* scale_buffer, - ElementZero const* zero_buffer, - ScaleLayout const scale_layout, - int const group_size, - cudaStream_t &stream) { - using namespace cute; - - constexpr int tpb = 128; - auto thr_layout = make_layout(make_shape(Int{})); - - const auto num_rows = get<0>(shape(operand_layout)); - const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] - const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] - const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] - - if (num_rows != size<0>(scale_layout)) { - std::cerr << "Invalid first dimension for scales. Must match first dim for weights." - << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) - << std::endl; - exit(-1); - } - - const auto scale_stride0 = get<0>(stride(scale_layout)); - const auto scale_stride1 = get<1>(stride(scale_layout)); - const auto scale_stride2 = get<2>(stride(scale_layout)); - - auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); - auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); - auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); - - const auto blocks_x = gemm_k; - const auto blocks_y = batches; - - dim3 blocks(blocks_x, blocks_y, 1); - dequantize_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); - CUDA_CHECK(cudaStreamSynchronize(stream)); -} - -template -class packed_scale_t { -public: - static_assert(cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v, - "only 8 bit arithmetic types are supported."); - CUTLASS_HOST_DEVICE - explicit packed_scale_t(T val) { - if constexpr (!cute::is_unsigned_v) { - // Only pack negative values. The positive values are generated in flight in the mainloop. - storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); - storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); - } - else { - storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); - storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); - } - } - CUTLASS_HOST_DEVICE - packed_scale_t() = default; - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - CUTLASS_HOST_DEVICE - bool operator==(packed_scale_t const& rhs) const { - return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; - } - CUTLASS_HOST_DEVICE - bool operator!=(packed_scale_t const& rhs) const { - return !(*this == rhs); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() + rhs.get()); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() - rhs.get()); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() * rhs.get()); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() / rhs.get()); - } - -private: - using Storage = uint32_t; - using Stage = uint8_t; - - Storage storage[2] {}; - - CUTLASS_HOST_DEVICE - static Storage pack4(T c1, T c2, T c3, T c4) { - Storage result = 0; - result |= (static_cast(reinterpret_cast(c4)) << 24); - result |= (static_cast(reinterpret_cast(c3)) << 16); - result |= (static_cast(reinterpret_cast(c2)) << 8); - result |= static_cast(reinterpret_cast(c1)); - return result; - } - CUTLASS_HOST_DEVICE - T get() const { - auto stage = static_cast(storage[0] >> 8); - #if defined(__CUDA_ARCH__) - return reinterpret_cast(stage); - #else - T tmp; - std::memcpy(&tmp, &stage, sizeof(Stage)); - return tmp; - #endif - } - CUTLASS_HOST_DEVICE - T get(int idx) const { - Stage stage; - if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); - else stage = static_cast(storage[1] >> (8 * idx - 32)); - #if defined(__CUDA_ARCH__) - return reinterpret_cast(stage); - #else - T tmp; - std::memcpy(&tmp, &stage, sizeof(Stage)); - return tmp; - #endif - } -}; - -// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. -// Here the encodings of positive values and negative values are unified (except for the sign bit). -// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). -static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) { - - using StorageType = cutlass::int4b_t::Storage; - constexpr int pack = cute::sizeof_bits_v / 4; - const size_t host_buf_size = block_size / pack; - std::vector host_buf(host_buf_size); - cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size); - - for (auto&& d : host_buf) { - StorageType out = 0; - StorageType mask = 0x0f; - for (int i = 0; i < pack; i++) { - cutlass::int4b_t curr; - curr.storage = (d >> (i * 4)) & 0x0f; - switch (curr) { - case 1: curr.storage = StorageType(0b0111); break; // 2's complement - case 2: curr.storage = StorageType(0b0110); break; // 2's complement - case 3: curr.storage = StorageType(0b0101); break; // 2's complement - case 4: curr.storage = StorageType(0b0100); break; // 2's complement - case 5: curr.storage = StorageType(0b0011); break; // 2's complement - case 6: curr.storage = StorageType(0b0010); break; // 2's complement - case 7: curr.storage = StorageType(0b0001); break; // 2's complement - default: break; - } - out |= (curr.storage << (4 * i)) & mask; - mask <<= 4; - } - d = out; - } - - cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size); - return true; -} - -template -static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array *block_out, const size_t block_size) { - std::vector data_in(block_size); - std::vector> data_out(block_size); - - try { - cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size); - } - catch (cutlass::cuda_exception const& e) { - std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; - return false; - } - - for (size_t i = 0; i < block_size; i++) { - cutlass::packed_scale_t tmp(data_in[i]); - data_out[i] = reinterpret_cast const&>(tmp); - } - - try { - cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size); - } - catch (cutlass::cuda_exception const& e) { - std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; - return false; - } - return true; -} - -template -struct UnderlyingElement { - using type = T; -}; - -template -struct UnderlyingElement> { - using type = typename T::Element; -}; - -// Given a type of MMA instruction, compute a memory reordering atom that places all values -// owned by each thread in contiguous memory locations. This improves smem load vectorization, -// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order -// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. -// In addition, we can reorder the values across several MMA instructions to get even wider -// vectorization (AtomLayout parameter) and permute the values within each instruction to get -// more optimal conversion instruction sequences (ValLayout parameter). -template , - class ValLayout = cute::Layout> -constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) -{ - using namespace cute; - - static_assert(is_static_v, "ValLayout must be static"); - static_assert(is_static_v, "AtomLayout must be static"); - - // 1. Choose an MMA atom to access TV layout and MN shape - // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary - using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); - using MmaTraits = MMA_Traits; - auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); - auto tv_layout_mma = typename MmaTraits::ALayout{}; - static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); - - // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) - // Note: this assumes A is partitioned between warps along M mode - auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); - auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); - auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); - auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); - - // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization - auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); - - // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization) - auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); - auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); - auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); - auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); - - return layout_atom; -} - -template -__global__ void reorder_tensor_kernel( - cute::Tensor S, - cute::Tensor D, - TiledCopy tiled_copy) -{ - using namespace cute; - - using T = typename EngineDst::value_type; - - Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); - Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); - - auto thread_copy = tiled_copy.get_slice(threadIdx.x); - Tensor tS = thread_copy.partition_S(gS); - Tensor tD = thread_copy.partition_D(gD); - - copy(tiled_copy, tS, tD); -} - -template -void reorder_tensor( - cute::Tensor S, - cute::Tensor D) -{ - using namespace cute; - - using T = typename EngineDst::value_type; - static_assert(is_same_v, T>, "Type mismatch"); - - // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread - // This avoids a race condition when writing out subbyte types (e.g. int4b_t). - auto has_major_mode = [](auto s) { - return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; }); - }; - static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), - "Could not find stride-1 mode in destination layout"); - constexpr int N = shape_div(Int<8>{}, Int>{}); - auto val_layout = conditional_return(LayoutDst{}))>( - make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), - make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); - - // Make a tiled copy with a simple row-major thread order and above layout - int constexpr NumThreads = 128; - auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); - auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); - - // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper - using TileShape = Shape<_16>; - auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); - dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; - - reorder_tensor_kernel<<>>(S, D, tiled_copy); - CUDA_CHECK(cudaDeviceSynchronize()); -} - -// In-place version -template -void reorder_tensor( - T const* src, - LayoutSrc const& layout_src, - T * dst, - LayoutDst const& layout_dst) -{ - using namespace cute; - reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), - make_tensor(make_gmem_ptr(dst), layout_dst)); -} - -// In-place version -template -void reorder_tensor( - T * data, - LayoutSrc const& layout_src, - LayoutDst const& layout_dst) -{ - using namespace cute; - cutlass::DeviceAllocation temp(size(layout_src)); - reorder_tensor(data, layout_src, temp.get(), layout_dst); - cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); -} - -#undef CUDA_CHECK - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp deleted file mode 100644 index 811ba152ab7c6e8fafc1cebdbb3726798fd16b3c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp +++ /dev/null @@ -1,570 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. -*/ - -#pragma once - -#include "cute/layout.hpp" -#include "cute/container/array.hpp" // cute::array -#include "cutlass/conv/convolution.h" // cutlass::conv::Operator - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Strides without batch mode - -template -CUTLASS_HOST_DEVICE -cute::Stride> -make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); - return s_copy; -} - -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT> -make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); - return s_copy; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Strides with batch mode - -template -CUTLASS_HOST_DEVICE -cute::Stride, int64_t> -make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); - int batch_count = cute::get<2>(shape_MKL); - if (batch_count > 1) { - cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); - } - else { - cute::get<2>(s_copy) = static_cast(0); - } - return s_copy; -} - -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, int64_t> -make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); - int batch_count = cute::get<2>(shape_MKL); - if (batch_count > 1) { - cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); - } - else { - cute::get<2>(s_copy) = static_cast(0); - } - return s_copy; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Strides with group mode - -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Int<0>> -make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); - return s_copy; -} - -template -CUTLASS_HOST_DEVICE -cute::Stride, StrideIntT, cute::Int<0>> -make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); - return s_copy; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Strides for convolutions - -// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) -// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order -// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout -// right in KTRSC order and can be coalesced to just k. -// We enforce this condition here with asserts. -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, cute::Int<0>> s, - cute::array shape_output, - cute::array stride_output, - cutlass::conv::Operator conv_op) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - static_assert(RankT_ >= 3u); - constexpr static int RankT = static_cast(RankT_); - - assert(stride_output[RankT-1] == 1); - cute::for_each(cute::make_seq{}, [&](auto i) { - assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); - }); - - auto s_copy = s; - cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? - stride_output[0] : - stride_output[RankT-2]; - return s_copy; -} - -// -// Activation tensor ((w, h, d, n), _1) for fprop kernel -// - -// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Int<1>> -make_cute_packed_stride( - cute::Stride, cute::Int<1>> s, - cute::array stride_nwc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - assert(stride_nwc[2] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_nwc[1]; - cute::get<0,1>(s_copy) = stride_nwc[0]; - return s_copy; -} - -// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Int<1>> -make_cute_packed_stride( - cute::Stride, cute::Int<1>> s, - cute::array stride_nhwc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - assert(stride_nhwc[3] == 1); - auto s_copy = s; - cute::for_each(cute::make_seq<3>{}, [&](auto i) { - cute::get<0,i>(s_copy) = stride_nhwc[2-i]; - }); - return s_copy; -} - -// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Int<1>> -make_cute_packed_stride( - cute::Stride, cute::Int<1>> s, - cute::array stride_ndhwc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ndhwc[4] == 1); - auto s_copy = s; - cute::for_each(cute::make_seq<4>{}, [&](auto i) { - cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; - }); - return s_copy; -} - -// -// Filter tensor (k, (_1, s, r, t)) for fprop kernel -// - -// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT>> -make_cute_packed_stride( - cute::Stride, IntT>> s, - cute::array stride_ksc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ksc[2] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_ksc[0]; - cute::get<1,1>(s_copy) = stride_ksc[1]; - return s_copy; -} - -// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, IntT>> -make_cute_packed_stride( - cute::Stride, IntT, IntT>> s, - cute::array stride_krsc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_krsc[3] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_krsc[0]; - cute::for_each(cute::make_seq<2>{}, [&](auto i) { - cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; - }); - return s_copy; -} - -// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, IntT, IntT>> -make_cute_packed_stride( - cute::Stride, IntT, IntT, IntT>> s, - cute::array stride_ktrsc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ktrsc[4] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_ktrsc[0]; - cute::for_each(cute::make_seq<3>{}, [&](auto i) { - cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; - }); - return s_copy; -} - -// -// Activation tensor (_1, (w, h, d, n)) for wgrad kernel -// -// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel -// - -// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad -// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Stride> -make_cute_packed_stride( - cute::Stride, cute::Stride> s, - cute::array stride_nwc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_nwc[2] == 1); - auto s_copy = s; - if (ConvOp == cutlass::conv::Operator::kWgrad) { - cute::get<1,0>(s_copy) = stride_nwc[1]; - cute::get<1,1>(s_copy) = stride_nwc[0]; - } - else if (ConvOp == cutlass::conv::Operator::kDgrad) { - // stride_nwc in dgrad is ksc. - cute::get<1,0>(s_copy) = stride_nwc[0]; - cute::get<1,1>(s_copy) = stride_nwc[1]; - } - return s_copy; -} - -// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad -// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Stride> -make_cute_packed_stride( - cute::Stride, cute::Stride> s, - cute::array stride_nhwc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_nhwc[3] == 1); - auto s_copy = s; - if (ConvOp == cutlass::conv::Operator::kWgrad) { - cute::for_each(cute::make_seq<3>{}, [&](auto i) { - cute::get<1,i>(s_copy) = stride_nhwc[2-i]; - }); - } - else if (ConvOp == cutlass::conv::Operator::kDgrad) { - // stride_nhwc in dgrad is krsc. - cute::get<1,0>(s_copy) = stride_nhwc[0]; - cute::for_each(cute::make_seq<2>{}, [&](auto i) { - cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; - }); - } - return s_copy; -} - -// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad -// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad -template -CUTLASS_HOST_DEVICE -cute::Stride, cute::Stride> -make_cute_packed_stride( - cute::Stride, cute::Stride> s, - cute::array stride_ndhwc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ndhwc[4] == 1); - auto s_copy = s; - if (ConvOp == cutlass::conv::Operator::kWgrad) { - cute::for_each(cute::make_seq<4>{}, [&](auto i) { - cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; - }); - } - else if (ConvOp == cutlass::conv::Operator::kDgrad) { - // stride_ndhwc in dgrad is ktrsc. - cute::get<1,0>(s_copy) = stride_ndhwc[0]; - cute::for_each(cute::make_seq<3>{}, [&](auto i) { - cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; - }); - } - return s_copy; -} - -// -// NZPQ tensor (_1, nzpq) for wgrad kernel -// - -// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT> -make_cute_packed_stride( - cute::Stride, IntT> s, - cute::array stride_nqk, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_nqk[2] == 1); - auto s_copy = s; - cute::get<1>(s_copy) = stride_nqk[1]; - return s_copy; -} - -// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT> -make_cute_packed_stride( - cute::Stride, IntT> s, - cute::array stride_npqk, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_npqk[3] == 1); - auto s_copy = s; - cute::get<1>(s_copy) = stride_npqk[2]; - return s_copy; -} - -// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT> -make_cute_packed_stride( - cute::Stride, IntT> s, - cute::array stride_nzpqk, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_nzpqk[4] == 1); - auto s_copy = s; - cute::get<1>(s_copy) = stride_nzpqk[3]; - return s_copy; -} - - - -// -// Wgrad output tensor (k, (_1, s, r, t), _0) -// - -// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT>, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, IntT>, cute::Int<0>> s, - [[maybe_unused]] cute::array shape_output, - cute::array stride_ksc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ksc[2] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_ksc[0]; - cute::get<1,1>(s_copy) = stride_ksc[1]; - return s_copy; -} - -// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, IntT>, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, IntT, IntT>, cute::Int<0>> s, - [[maybe_unused]] cute::array shape_output, - cute::array stride_krsc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_krsc[3] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_krsc[0]; - cute::for_each(cute::make_seq<2>{}, [&](auto i) { - cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; - }); - return s_copy; -} - -// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, IntT, IntT>, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, - [[maybe_unused]] cute::array shape_output, - cute::array stride_ktrsc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ktrsc[4] == 1); - auto s_copy = s; - cute::get<0,0>(s_copy) = stride_ktrsc[0]; - cute::for_each(cute::make_seq<3>{}, [&](auto i) { - cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; - }); - return s_copy; -} - - -// -// Wgrad output tensor ((_1, s, r, t), k, _0) -// - -// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT>, IntT, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, IntT>, IntT, cute::Int<0>> s, - [[maybe_unused]] cute::array shape_output, - cute::array stride_ksc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ksc[2] == 1); - auto s_copy = s; - cute::get<1,0>(s_copy) = stride_ksc[0]; - cute::get<0,1>(s_copy) = stride_ksc[1]; - return s_copy; -} - -// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, IntT>, IntT, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, - [[maybe_unused]] cute::array shape_output, - cute::array stride_krsc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_krsc[3] == 1); - auto s_copy = s; - cute::get<1,0>(s_copy) = stride_krsc[0]; - cute::for_each(cute::make_seq<2>{}, [&](auto i) { - cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; - }); - return s_copy; -} - -// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) -template -CUTLASS_HOST_DEVICE -cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> -make_cute_packed_stride( - cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, - [[maybe_unused]] cute::array shape_output, - cute::array stride_ktrsc, - conv::Operator ConvOp) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set dynamically. Static strides not supported."); - - assert(stride_ktrsc[4] == 1); - auto s_copy = s; - cute::get<1,0>(s_copy) = stride_ktrsc[0]; - cute::for_each(cute::make_seq<3>{}, [&](auto i) { - cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; - }); - return s_copy; -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp deleted file mode 100644 index c38ad3f710c18e5be1bb7e01dc66d7efcd2646d9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp +++ /dev/null @@ -1,341 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include - -#include -#include - -#include - -// The computed infinity norm does not include -// any NaN column absolute-value sums. -struct matrix_inf_norm_result { - // Accumulate errors in double, as this is generally - // the highest precision that the examples use. - double inf_norm = 0.0; - bool found_nan = false; -}; - -// In theory, cute::Tensor, T> could be treated as a view type, -// and thus passed by value (as std::span or std::string_view would be). -// However, generic cute::Tensor are more like containers -// and thus are best passed by reference or const reference. -template -matrix_inf_norm_result -matrix_inf_norm(cute::Tensor const& host_matrix) -{ - using error_type = decltype(std::declval().inf_norm); - using element_type = typename EngineType::value_type; - - error_type inf_norm = 0.0; - bool found_nan = false; - - // Computing the infinity norm requires that we be able - // to treat the input as a matrix, with rows and columns. - const int64_t num_rows = cute::size<0>(host_matrix); - const int64_t num_cols = cute::size<1>(host_matrix); - - auto abs_fn = [] (element_type A_ij) { - if constexpr (not std::is_unsigned_v) { - using std::abs; - return abs(A_ij); - } - else { - return A_ij; - } - }; - - for (int64_t i = 0; i < num_rows; ++i) { - error_type row_abs_sum = 0.0; - for(int64_t j = 0; j < num_cols; ++j) { - row_abs_sum += abs_fn(host_matrix(i, j)); - } - if (std::isnan(row_abs_sum)) { - found_nan = true; - } - else { - inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; - } - } - - return {inf_norm, found_nan}; -} - -// Infinity norm of (X - Y). -template -matrix_inf_norm_result -matrix_diff_inf_norm(cute::Tensor const& X, - cute::Tensor const& Y) -{ - using error_type = decltype(std::declval().inf_norm); - using element_type = typename EngineType::value_type; - - auto abs_fn = [] (element_type A_ij) { - if constexpr (not std::is_unsigned_v) { - using std::abs; - return abs(A_ij); - } - else { - return A_ij; - } - }; - - assert(cute::size<0>(X) == cute::size<0>(Y)); - assert(cute::size<1>(X) == cute::size<1>(Y)); - - // Computing the infinity norm requires that we be able - // to treat the input as a matrix, with rows and columns. - const int64_t num_rows = cute::size<0>(X); - const int64_t num_cols = cute::size<1>(X); - - error_type inf_norm = 0.0; - bool found_nan = false; - - for (int64_t i = 0; i < num_rows; ++i) { - error_type row_abs_sum = 0.0; - for (int64_t j = 0; j < num_cols; ++j) { - row_abs_sum += error_type(abs_fn(element_type(X(i,j)) - - element_type(Y(i,j)))); - } - if (std::isnan(row_abs_sum)) { - found_nan = true; - } - else { - inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; - } - } - - return {inf_norm, found_nan}; -} - -template -auto -print_matrix_multiply_mollified_relative_error( - char const A_value_type_name[], - cute::Tensor const& A, - char const B_value_type_name[], - cute::Tensor const& B, - char const C_value_type_name[], - cute::Tensor const& C, - cute::Tensor const& C_ref) -{ - const auto [A_norm, A_has_nan] = matrix_inf_norm(A); - const auto [B_norm, B_has_nan] = matrix_inf_norm(B); - const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref); - const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref); - - const auto A_norm_times_B_norm = A_norm * B_norm; - const auto relative_error = A_norm_times_B_norm == 0.0 ? - diff_norm : (diff_norm / A_norm_times_B_norm); - - // For expected error bounds, please refer to the LAPACK Users' Guide, - // in particular https://netlib.org/lapack/lug/node108.html . - // Printing the infinity norm of C is a way to check - // that both the function being tested (C) - // and the reference implementation (C_ref) - // don't just do nothing (or fill with zeros). - using std::cout; - using cute::shape; - cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' - << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' - << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' - << std::scientific - << "Infinity norm of A: " << A_norm << '\n' - << "Infinity norm of B: " << B_norm << '\n' - << "Infinity norm of C: " << C_norm << '\n' - << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; - - if(A_norm_times_B_norm == 0.0) { - cout << "Mollified relative error: " << relative_error << '\n'; - } else { - cout << "Relative error: " << relative_error << '\n'; - } - - if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { - cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n'; - } - return relative_error; -} - -template -auto -print_matrix_multiply_mollified_relative_error( - const char value_type_name[], - const cute::Tensor& A, - const cute::Tensor& B, - const cute::Tensor& C_computed, - const cute::Tensor& C_expected) -{ - return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, - value_type_name, C_computed, C_expected); -} - -// Take a CUTLASS HostTensor (or the like) as input, -// and return a const CuTe Tensor. -// This is useful for use with the above error printing functions. -// This implicitly "transposes" if the layout is RowMajor. -// Note that the HostTensor must be captured by nonconst reference -// in order for X.host_ref().data() to compile. -// (CUTLASS is a bit more container-y than CuTe.) -template -auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) -{ - // The tensors were created with post-transposed extents. - const auto extents = X.extent(); - const auto shape = cute::Shape{extents[0], extents[1]}; - // Both RowMajor and ColumnMajor only store one stride. - const int LDX = X.stride(0); - const auto strides = [&]() { - using input_layout_type = typename std::decay_t::Layout; - if constexpr (std::is_same_v) { - return cute::Stride{1, LDX}; - } - else { - static_assert(std::is_same_v); - return cute::Stride{LDX, 1}; - } - }(); - const auto layout = cute::make_layout(shape, strides); - auto X_data = X.host_ref().data(); - auto X_data_const = const_cast >(X_data); - return cute::make_tensor(X_data_const, layout); -}; - - -// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE. -// This makes the return value suitable as the return value of main(). -template -int -print_relative_error( - std::size_t n, - T1 const& data, - T2 const& reference, - bool print_verbose = false, - bool print_error = true, - double error_margin = 0.00001) { - using std::abs; using std::sqrt; - - // Use either double or complex for error computation - using value_type = cute::remove_cvref_t; - using error_type = std::conditional_t::value, - cute::complex, - double>; - - if (print_verbose) { - std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl; - } - - double eps = 1e-200; - - double tot_error_sq = 0; - double tot_norm_sq = 0; - double tot_ind_rel_err = 0; - double max_ind_rel_err = 0; - double max_diff = 0; - for (std::size_t i = 0; i < n; ++i) { - error_type val = data[i]; - error_type ref = reference[i]; - - double aref = abs(ref); - double diff = abs(ref - val); - double rel_error = diff / (aref + eps); - - // Individual relative error - tot_ind_rel_err += rel_error; - - // Maximum relative error - max_ind_rel_err = std::max(max_ind_rel_err, rel_error); - - // Maximum delta in value error - max_diff = std::max(max_diff, diff); - - // Total relative error - tot_error_sq += diff * diff; - tot_norm_sq += aref * aref; - - if (print_verbose) { - std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl; - } - } - - double ave_rel_err = tot_ind_rel_err / double(n); - if (print_error) { - printf("Average relative error: %.3e\n", ave_rel_err); - } - - if (print_error) { - printf("Maximum relative error: %.3e\n", max_ind_rel_err); - } - - if (print_error) { - printf("Maximum difference : %.3e\n", max_diff); - } - - double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps)); - if (print_error) { - printf("Vector relative error: %.3e\n", tot_rel_err); - } - - printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq)); - - return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE; -} - -// Overload for cute::Tensor<> -template -int -print_relative_error( - cute::Tensor data, - cute::Tensor reference, - bool print_verbose = false, - bool print_error = true, - double error_margin = 0.00001) { - assert(size(data) == size(reference)); - return print_relative_error(static_cast(size(data)), - data, reference, - print_verbose, print_error, error_margin); -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h deleted file mode 100644 index 8167c91bf2330d160a78ba210449357b395964ca..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h +++ /dev/null @@ -1,135 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GEMM in host-side code. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -namespace cutlass { -namespace reference { -namespace detail { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template function to compute an inner product. -#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a - // host-only type -template -CUTLASS_HOST_DEVICE -Ctype inner_product(Atype a, Btype b, Ctype c) { - return Ctype(a) * Ctype(b) + c; -} - -/// Specialization for matrix multiplication with binary operands -template <> -CUTLASS_HOST_DEVICE -int inner_product, Array, int>( - Array a, - Array b, - int c) { - - int accum = 0; - for (int bit = 0; bit < 32; bit++) { - accum += a[bit] ^ b[bit]; - } - return accum + c; -} - -/* -/// Specialization for matrix multiplication with signed 4-bit integer operands -template <> -CUTLASS_HOST_DEVICE -int inner_product, Array, int>( - Array a, - Array b, - int c) { - - int accum = 0; - for (int k = 0; k < 8; k++) { - accum += a[k] * b[k]; - } - return accum + c; -} - -/// Specialization for matrix multiplication with unsigned 4-bit integer operands -template <> -CUTLASS_HOST_DEVICE -int inner_product, Array, int>( - Array a, - Array b, - int c) { - - int accum = 0; - for (int k = 0; k < 8; k++) { - accum += a[k] * b[k]; - } - return accum + c; -} -*/ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Cast { - // Default behavior: convert to the destination type -#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a - // host-only type - CUTLASS_HOST_DEVICE - static DstType apply(SrcType src) { return static_cast(src); }; -}; - -template <> -struct Cast { - CUTLASS_HOST_DEVICE - static int8_t apply(float src) { - // Clamp to the range of signed 8-bit integers. - return static_cast(fmaxf(-128.f, fminf(127.f, src))); - }; -}; - -template <> -struct Cast { - CUTLASS_HOST_DEVICE - static uint8_t apply(float src) { - // Clamp to the range of signed 8-bit integers. - return static_cast(fmaxf(0.f, fminf(255.f, src))); - }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail -} // namespace reference -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h deleted file mode 100644 index 652d622586cb202ecfe69ac892978b649b5d1be7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h +++ /dev/null @@ -1,94 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GEMM in host-side code. -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LinearToCoordinateHelper { - - CUTLASS_HOST_DEVICE - void operator()(Coord &coord, int64_t idx, Coord const &extent) const { - - int64_t prod = 1; - - CUTLASS_PRAGMA_UNROLL - for (int i = Rank - Index; i < Rank; ++i) { - prod *= int64_t(extent[i]); - } - - coord[Rank - Index - 1] = int(idx / prod); - - int64_t residual = idx % prod; - LinearToCoordinateHelper()(coord, residual, extent); - } -}; - -template -struct LinearToCoordinateHelper { - - CUTLASS_HOST_DEVICE - void operator()(Coord &coord, int64_t idx, Coord const &) const { - coord[Rank - 1] = int(idx); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LinearToCoordinate { - - CUTLASS_HOST_DEVICE - void operator()(Coord &coord, int64_t idx, Coord const &extent) const { - LinearToCoordinateHelper()(coord, idx, extent); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail -} // namespace reference -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h deleted file mode 100644 index 7c6f803c47f5c407cf058d40bc8274a448a36dc4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h +++ /dev/null @@ -1,1549 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Reference implementation for convolution in device-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/functional.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" - -namespace cutlass { -namespace reference { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Conv2d device reference kernel -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2d Fprop kernel - y = fprop(x, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension - int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension - int kCtaShapeM = 16, // shape of a threadblock in units of threads - int kCtaShapeN = 8 // shape of a threadblock in units of threads -> -__global__ void Conv2dFprop( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_x, - TensorRef tensor_w, - TensorRef tensor_y_in, - TensorRef tensor_y_out, - ElementCompute alpha, - ElementCompute beta - ) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - ElementAccumulator element_A[kThreadM]; - ElementAccumulator element_B[kThreadN]; - ElementAccumulator accum[kThreadM][kThreadN]; - - int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; - int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; - - int thread_n[kThreadM]; - int thread_p[kThreadM]; - int thread_q[kThreadM]; - - // Compute N, P, Q coordinates for each row of a thread's tile - int64_t PQ = int64_t(problem_size.P) * problem_size.Q; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - int64_t npq = npq_start + m; - - thread_n[m] = int(npq / PQ); - - int64_t residual = npq % PQ; - thread_p[m] = int(residual / problem_size.Q); - thread_q[m] = int(residual % problem_size.Q); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = ElementAccumulator(); - } - } - - int c_per_group = problem_size.C / problem_size.groups; - int k_per_group = problem_size.K / problem_size.groups; - - // Compute convolution - for (int R = 0; R < problem_size.R; ++R) { - for (int S = 0; S < problem_size.S; ++S) { - for (int C = 0; C < problem_size.C; ++C) { - - // Get group id of currnet channel - int c_group_idx = C / c_per_group; - - // Load from activations tensor - int filter_r = R; - int filter_s = S; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_r = problem_size.R - 1 - R; - filter_s = problem_size.S - 1 - S; - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; - int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; - - if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { - element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); - } - else { - element_A[m] = ElementAccumulator(); - } - } - - // Load from filters tensor - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_k = k_start + n; - int k_group_idx = thread_k / k_per_group; - - if (thread_k < problem_size.K && k_group_idx == c_group_idx) { - element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); - } - else { - element_B[n] = ElementAccumulator(); - } - } - - // Accumulate matrix product - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); - } - } - } - } - } - - // Write out the results - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_k = k_start + n; - if (thread_k < problem_size.K) { - - ElementCompute c_ref = ElementCompute(); - if (beta != ElementCompute()) { - c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); - } - - tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( - alpha * ElementCompute(accum[m][n]) + beta * c_ref); - } - } - } - } -} - -// Conv3d Fprop kernel - y = fprop(x, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension - int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension - int kCtaShapeM = 16, // shape of a threadblock in units of threads - int kCtaShapeN = 8 // shape of a threadblock in units of threads -> -__global__ void Conv3dFprop( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_x, - TensorRef tensor_w, - TensorRef tensor_y_in, - TensorRef tensor_y_out, - ElementCompute alpha, - ElementCompute beta - ) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - ElementAccumulator element_A[kThreadM]; - ElementAccumulator element_B[kThreadN]; - ElementAccumulator accum[kThreadM][kThreadN]; - - int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; - int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; - - int thread_n[kThreadM]; - int thread_z[kThreadM]; - int thread_p[kThreadM]; - int thread_q[kThreadM]; - - // Compute N, Z, P, Q coordinates for each row of a thread's tile - int64_t PQ = int64_t(problem_size.P) * problem_size.Q; - int64_t ZPQ = PQ * problem_size.Z; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - int64_t nzpq = nzpq_start + m; - - thread_n[m] = int(nzpq / ZPQ); - - int64_t residual = nzpq % ZPQ; - thread_z[m] = int(residual / PQ); - - residual = residual % PQ; - thread_p[m] = int(residual / problem_size.Q); - thread_q[m] = int(residual % problem_size.Q); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = ElementAccumulator(); - } - } - - // Compute convolution - for (int T = 0; T < problem_size.T; ++T) { - for (int R = 0; R < problem_size.R; ++R) { - for (int S = 0; S < problem_size.S; ++S) { - for (int C = 0; C < problem_size.C; ++C) { - - // Load from activations tensor - int filter_t = T; - int filter_r = R; - int filter_s = S; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_t = problem_size.T - 1 - T; - filter_r = problem_size.R - 1 - R; - filter_s = problem_size.S - 1 - S; - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; - int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; - int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; - - if (thread_n[m] < problem_size.N && - d >= 0 && d < problem_size.D && - h >= 0 && h < problem_size.H && - w >= 0 && w < problem_size.W) { - - element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); - } - else { - element_A[m] = ElementAccumulator(); - } - } - - // Load from filters tensor - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_k = k_start + n; - - if (thread_k < problem_size.K) { - element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); - } - else { - element_B[n] = ElementAccumulator(); - } - } - - // Accumulate matrix product - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); - } - } - - } // for (C) - } // for (S) - } // for (R) - } // for (T) - - // Write out the results - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - if (thread_n[m] < problem_size.N && - thread_z[m] < problem_size.Z && - thread_p[m] < problem_size.P && - thread_q[m] < problem_size.Q) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_k = k_start + n; - if (thread_k < problem_size.K) { - - ElementCompute c_ref = ElementCompute(); - if (beta != ElementCompute()) { - c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); - } - - tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( - alpha * ElementCompute(accum[m][n]) + beta * c_ref); - } - } // for (n) - - } - } // for (m) -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2d dgrad kernel - dx = dgrad(dy, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension - int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension - int kCtaShapeM = 16, // shape of a threadblock in units of threads - int kCtaShapeN = 8 // shape of a threadblock in units of threads -> -__global__ void Conv2dDgrad( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_w, - TensorRef tensor_dx_in, - TensorRef tensor_dx_out, - ElementCompute alpha, - ElementCompute beta - ) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - ElementAccumulator element_A[kThreadM]; - ElementAccumulator element_B[kThreadN]; - ElementAccumulator accum[kThreadM][kThreadN]; - - int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; - int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; - - int thread_n[kThreadM]; - int thread_h[kThreadM]; - int thread_w[kThreadM]; - - // Compute N, H, W coordinates for each row of a thread's tile - int64_t HW = int64_t(problem_size.H) * problem_size.W; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - int64_t nhw = nhw_start + m; - - thread_n[m] = int(nhw / HW); - - int64_t residual = nhw % HW; - thread_h[m] = int(residual / problem_size.W); - thread_w[m] = int(residual % problem_size.W); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = ElementAccumulator(); - } - } - - // Compute convolution - for (int R = 0; R < problem_size.R; ++R) { - for (int S = 0; S < problem_size.S; ++S) { - for (int K = 0; K < problem_size.K; ++K) { - - // Load from activations tensor - int filter_r = R; - int filter_s = S; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_r = problem_size.R - 1 - R; - filter_s = problem_size.S - 1 - S; - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; - int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; - - element_A[m] = ElementAccumulator(); - - if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { - - p = p / problem_size.stride_h; - q = q / problem_size.stride_w; - - if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { - element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); - } - } - } - - // Load from filters tensor - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_c = c_start + n; - - if (thread_c < problem_size.C) { - element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); - } - else { - element_B[n] = ElementAccumulator(); - } - } - - // Accumulate matrix product - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); - } - } - } - } - } - - // Write out the results - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_c = c_start + n; - if (thread_c < problem_size.C) { - - ElementCompute c_ref = ElementCompute(); - if (beta != ElementCompute()) { - c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); - } - - tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( - alpha * ElementCompute(accum[m][n]) + beta * c_ref); - } - } - } - } -} - -// Conv3d dgrad kernel - dx = dgrad(dy, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension - int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension - int kCtaShapeM = 16, // shape of a threadblock in units of threads - int kCtaShapeN = 8 // shape of a threadblock in units of threads -> -__global__ void Conv3dDgrad( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_w, - TensorRef tensor_dx_in, - TensorRef tensor_dx_out, - ElementCompute alpha, - ElementCompute beta - ) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - ElementAccumulator element_A[kThreadM]; - ElementAccumulator element_B[kThreadN]; - ElementAccumulator accum[kThreadM][kThreadN]; - - int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; - int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; - - int thread_n[kThreadM]; - int thread_d[kThreadM]; - int thread_h[kThreadM]; - int thread_w[kThreadM]; - - // Compute N, H, W coordinates for each row of a thread's tile - int64_t HW = int64_t(problem_size.H) * problem_size.W; - int64_t DHW = HW * problem_size.D; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - int64_t ndhw = ndhw_start + m; - - thread_n[m] = int(ndhw / DHW); - - int64_t residual = ndhw % DHW; - thread_d[m] = int(residual / HW); - - residual = residual % HW; - thread_h[m] = int(residual / problem_size.W); - thread_w[m] = int(residual % problem_size.W); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = ElementAccumulator(); - } - } - - // Compute convolution - for (int T = 0; T < problem_size.T; ++T) { - for (int R = 0; R < problem_size.R; ++R) { - for (int S = 0; S < problem_size.S; ++S) { - for (int K = 0; K < problem_size.K; ++K) { - - // Load from activations tensor - int filter_t = T; - int filter_r = R; - int filter_s = S; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_t = problem_size.T - 1 - T; - filter_r = problem_size.R - 1 - R; - filter_s = problem_size.S - 1 - S; - } - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; - int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; - int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; - - element_A[m] = ElementAccumulator(); - - if (z >= 0 && !(z % problem_size.stride_d) && - p >= 0 && !(p % problem_size.stride_h) && - q >= 0 && !(q % problem_size.stride_w)) { - - z = z / problem_size.stride_d; - p = p / problem_size.stride_h; - q = q / problem_size.stride_w; - - if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { - element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); - } - } - } - - // Load from filters tensor - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_c = c_start + n; - - if (thread_c < problem_size.C) { - element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); - } - else { - element_B[n] = ElementAccumulator(); - } - } - - // Accumulate matrix product - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); - } - } - - } // for (C) - } // for (S) - } // for (R) - } // for (T) - - // Write out the results - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - - if (thread_n[m] < problem_size.N && - thread_d[m] < problem_size.D && - thread_h[m] < problem_size.H && - thread_w[m] < problem_size.W) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - int thread_c = c_start + n; - if (thread_c < problem_size.C) { - - ElementCompute c_ref = ElementCompute(); - if (beta != ElementCompute()) { - c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); - } - - tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( - alpha * ElementCompute(accum[m][n]) + beta * c_ref); - } - } - } - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// Conv2d wgrad kernel - dw = wgrad(dy, x) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension - int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension - int kCtaShapeM = 8, // shape of a threadblock in units of threads - int kCtaShapeN = 16 // shape of a threadblock in units of threads -> -__global__ void Conv2dWgrad( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_x, - TensorRef tensor_dw_in, - TensorRef tensor_dw_out, - ElementCompute alpha, - ElementCompute beta - ) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - ElementAccumulator element_A[kThreadM]; - ElementAccumulator element_B[kThreadN]; - ElementAccumulator accum[kThreadM][kThreadN]; - - int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; - int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; - - int thread_r[kThreadN]; - int thread_s[kThreadN]; - int thread_c[kThreadN]; - - // Compute R, S, C coordinates for each row of a thread's tile - int64_t SC = int64_t(problem_size.S) * problem_size.C; - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - - int64_t rsc = rsc_start + n; - int64_t residual = rsc % SC; - - thread_r[n] = int(rsc / SC); - thread_s[n] = int(residual / problem_size.C); - thread_c[n] = int(residual % problem_size.C); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = ElementAccumulator(); - } - } - - // Compute convolution - for (int N = 0; N < problem_size.N; ++N) { - for (int P = 0; P < problem_size.P; ++P) { - for (int Q = 0; Q < problem_size.Q; ++Q) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - int thread_k = k_start + m; - - element_A[m] = ElementAccumulator(); - - if (thread_k < problem_size.K) { - element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); - } - } - - // Load from filters tensor - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - - // Load from activations tensor - int filter_r = thread_r[n]; - int filter_s = thread_s[n]; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_r = problem_size.R - 1 - filter_r; - filter_s = problem_size.S - 1 - filter_s; - } - - int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; - int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; - - element_B[n] = ElementAccumulator(); - - if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { - element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); - } - } - - // Accumulate matrix product - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); - } - } - } - } - } - - // Write out the results - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - int thread_k = k_start + m; - - if (thread_k < problem_size.K) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - - if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { - - ElementCompute c_ref = ElementCompute(); - - if (beta != ElementCompute()) { - c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); - } - - tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( - alpha * ElementCompute(accum[m][n]) + beta * c_ref); - } - } - } - } -} - -// Conv3d wgrad kernel - dw = wgrad(dy, x) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension - int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension - int kCtaShapeM = 8, // shape of a threadblock in units of threads - int kCtaShapeN = 16 // shape of a threadblock in units of threads -> -__global__ void Conv3dWgrad( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_x, - TensorRef tensor_dw_in, - TensorRef tensor_dw_out, - ElementCompute alpha, - ElementCompute beta - ) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - ElementAccumulator element_A[kThreadM]; - ElementAccumulator element_B[kThreadN]; - ElementAccumulator accum[kThreadM][kThreadN]; - - int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; - int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; - - int thread_t[kThreadN]; - int thread_r[kThreadN]; - int thread_s[kThreadN]; - int thread_c[kThreadN]; - - // Compute R, S, C coordinates for each row of a thread's tile - int64_t SC = int64_t(problem_size.S) * problem_size.C; - int64_t RSC = SC * problem_size.R; - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - - int64_t trsc = trsc_start + n; - - thread_t[n] = int(trsc / RSC); - - int64_t residual = trsc % RSC; - thread_r[n] = int(residual / SC); - - residual = residual % SC; - thread_s[n] = int(residual / problem_size.C); - thread_c[n] = int(residual % problem_size.C); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = ElementAccumulator(); - } - } - - // Compute convolution - for (int N = 0; N < problem_size.N; ++N) { - for (int Z = 0; Z < problem_size.Z; ++Z) { - for (int P = 0; P < problem_size.P; ++P) { - for (int Q = 0; Q < problem_size.Q; ++Q) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - int thread_k = k_start + m; - - element_A[m] = ElementAccumulator(); - - if (thread_k < problem_size.K) { - element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); - } - } - - // Load from filters tensor - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - - // Load from activations tensor - int filter_t = thread_t[n]; - int filter_r = thread_r[n]; - int filter_s = thread_s[n]; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_t = problem_size.T - 1 - filter_t; - filter_r = problem_size.R - 1 - filter_r; - filter_s = problem_size.S - 1 - filter_s; - } - - int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; - int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; - int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; - - element_B[n] = ElementAccumulator(); - - if (d >= 0 && d < problem_size.D && - h >= 0 && h < problem_size.H && - w >= 0 && w < problem_size.W && - thread_c[n] < problem_size.C) { - - element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); - } - } - - // Accumulate matrix product - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); - } - } - - } // for (Q) - } // for (P) - } // for (Z) - } // for (N) - - // Write out the results - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < kThreadM; ++m) { - int thread_k = k_start + m; - - if (thread_k < problem_size.K) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < kThreadN; ++n) { - - if (thread_t[n] < problem_size.T && - thread_r[n] < problem_size.R && - thread_s[n] < problem_size.S && - thread_c[n] < problem_size.C) { - - ElementCompute c_ref = ElementCompute(); - - if (beta != ElementCompute()) { - c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); - } - - tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( - alpha * ElementCompute(accum[m][n]) + beta * c_ref); - } - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Conv2d Fprop dispatcher - y = fprop(x, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv2dFprop( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_x, - TensorRef tensor_w, - TensorRef tensor_y_in, - TensorRef tensor_y_out, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - // - // Blocking factors improve performance of reference implementation - // - - int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension - int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension - int const kCtaShapeM = 16; // shape of a threadblock in units of threads - int const kCtaShapeN = 8; // shape of a threadblock in units of threads - - int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; - int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); - - dim3 block(kCtaShapeM, kCtaShapeN); - dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); - - kernel::Conv2dFprop< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp, - kThreadM, - kThreadN, - kCtaShapeM, - kCtaShapeN - ><<< grid, block, 0, stream >>>( - problem_size, - tensor_x, - tensor_w, - tensor_y_in, - tensor_y_out, - alpha, - beta - ); - - cudaError_t result = cudaPeekAtLastError(); - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - - return Status::kSuccess; -} - -/// Conv3d Fprop dispatcher - y = fprop(x, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv3dFprop( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_x, - TensorRef tensor_w, - TensorRef tensor_y_in, - TensorRef tensor_y_out, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - // - // Blocking factors improve performance of reference implementation - // - - int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension - int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension - int const kCtaShapeM = 16; // shape of a threadblock in units of threads - int const kCtaShapeN = 8; // shape of a threadblock in units of threads - - int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; - int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); - - dim3 block(kCtaShapeM, kCtaShapeN); - dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); - - kernel::Conv3dFprop< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp, - kThreadM, - kThreadN, - kCtaShapeM, - kCtaShapeN - ><<< grid, block, 0, stream >>>( - problem_size, - tensor_x, - tensor_w, - tensor_y_in, - tensor_y_out, - alpha, - beta - ); - - cudaError_t result = cudaPeekAtLastError(); - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - - return Status::kSuccess; -} - -/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv2dDgrad( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_w, - TensorRef tensor_dx_in, - TensorRef tensor_dx_out, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - // - // Blocking factors improve performance of reference implementation - // - - int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension - int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension - int const kCtaShapeM = 16; // shape of a threadblock in units of threads - int const kCtaShapeN = 8; // shape of a threadblock in units of threads - - int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; - int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); - - dim3 block(kCtaShapeM, kCtaShapeN); - dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); - - kernel::Conv2dDgrad< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp, - kThreadM, - kThreadN, - kCtaShapeM, - kCtaShapeN - ><<< grid, block, 0, stream >>>( - problem_size, - tensor_dy, - tensor_w, - tensor_dx_in, - tensor_dx_out, - alpha, - beta - ); - - cudaError_t result = cudaPeekAtLastError(); - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - - return Status::kSuccess; -} - -/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv3dDgrad( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_w, - TensorRef tensor_dx_in, - TensorRef tensor_dx_out, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - // - // Blocking factors improve performance of reference implementation - // - - int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension - int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension - int const kCtaShapeM = 16; // shape of a threadblock in units of threads - int const kCtaShapeN = 8; // shape of a threadblock in units of threads - - int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; - int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); - - dim3 block(kCtaShapeM, kCtaShapeN); - dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); - - kernel::Conv3dDgrad< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp, - kThreadM, - kThreadN, - kCtaShapeM, - kCtaShapeN - ><<< grid, block, 0, stream >>>( - problem_size, - tensor_dy, - tensor_w, - tensor_dx_in, - tensor_dx_out, - alpha, - beta - ); - - cudaError_t result = cudaPeekAtLastError(); - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - - return Status::kSuccess; -} - -/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv2dWgrad( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_x, - TensorRef tensor_dw_in, - TensorRef tensor_dw_out, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - // - // Blocking factors improve performance of reference implementation - // - - int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension - int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension - int const kCtaShapeM = 8; // shape of a threadblock in units of threads - int const kCtaShapeN = 16; // shape of a threadblock in units of threads - - int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; - int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); - - dim3 block(kCtaShapeM, kCtaShapeN); - dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); - - kernel::Conv2dWgrad< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp, - kThreadM, - kThreadN, - kCtaShapeM, - kCtaShapeN - ><<< grid, block, 0, stream >>>( - problem_size, - tensor_dy, - tensor_x, - tensor_dw_in, - tensor_dw_out, - alpha, - beta - ); - - cudaError_t result = cudaPeekAtLastError(); - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - - return Status::kSuccess; -} - -/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv3dWgrad( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_x, - TensorRef tensor_dw_in, - TensorRef tensor_dw_out, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - // - // Blocking factors improve performance of reference implementation - // - - int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension - int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension - int const kCtaShapeM = 8; // shape of a threadblock in units of threads - int const kCtaShapeN = 16; // shape of a threadblock in units of threads - - int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; - int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); - - dim3 block(kCtaShapeM, kCtaShapeN); - dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); - - kernel::Conv3dWgrad< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, - InnerProductOp, - kThreadM, - kThreadN, - kCtaShapeM, - kCtaShapeN - ><<< grid, block, 0, stream >>>( - problem_size, - tensor_dy, - tensor_x, - tensor_dw_in, - tensor_dw_out, - alpha, - beta - ); - - cudaError_t result = cudaPeekAtLastError(); - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - - return Status::kSuccess; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv2d( - conv::Operator convolutional_operator, - conv::Conv2dProblemSize problem_size, - TensorRef tensor_A, - TensorRef tensor_B, - TensorRef tensor_C, - TensorRef tensor_D, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - switch (convolutional_operator) { - case conv::Operator::kFprop: - return Conv2dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); - break; - - case conv::Operator::kDgrad: - return Conv2dDgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); - break; - - case conv::Operator::kWgrad: - return Conv2dWgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); - break; - - default: break; - } - - return Status::kErrorNotSupported; -} - -/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -Status Conv3d( - conv::Operator convolutional_operator, - conv::Conv3dProblemSize problem_size, - TensorRef tensor_A, - TensorRef tensor_B, - TensorRef tensor_C, - TensorRef tensor_D, - ElementCompute alpha, - ElementCompute beta, - cudaStream_t stream = nullptr) { - - switch (convolutional_operator) { - case conv::Operator::kFprop: - return Conv3dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); - - case conv::Operator::kDgrad: - return Conv3dDgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); - - case conv::Operator::kWgrad: - return Conv3dWgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); - - default: break; - } - - return Status::kErrorNotSupported; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h deleted file mode 100644 index 7d575d522c1dd87d51f9bc58d09786393c5cfea3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h +++ /dev/null @@ -1,385 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GEMM in device-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/util/reference/device/kernel/gemm.h" - -namespace cutlass { -namespace reference { -namespace device { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename AccumulatorType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_gemm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - AccumulatorType initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - // Blocking structure potentially improves performance of reference implementation - // with a minor increase in complexity. - // - // Note, this reference implementation is NOT expected to approach peak performance. - using OutputTile = MatrixShape<4, 4>; - - dim3 block(16, 8); - - dim3 grid( - (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), - (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) - ); - - // Launch a GEMM kernel - kernel::Gemm< - TensorRef, - TensorRef, - TensorRef, - ScalarType, - AccumulatorType, - OutputTile, - InnerProductOp, - ConvertOp - ><<< grid, block >>>( - problem_size, - alpha, - tensor_a, - tensor_b, - beta, - tensor_c, - tensor_d, - initial_accum - ); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename AccumulatorType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_gemm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - AccumulatorType initial_accum) { - - compute_gemm( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, - initial_accum); -} - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename AccumulatorType, - typename InnerProductOp = cutlass::arch::OpMultiplyAdd -> -struct Gemm; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - AccumulatorType initial_accum = AccumulatorType(0)) { - - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - AccumulatorType initial_accum = AccumulatorType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add-saturate -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - AccumulatorType initial_accum = AccumulatorType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm, - NumericConverterClamp>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - AccumulatorType initial_accum = AccumulatorType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm, - NumericConverterClamp>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for XOR-popc -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - AccumulatorType initial_accum = AccumulatorType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - AccumulatorType initial_accum = AccumulatorType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Batched GEMM -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a batch of GEMMs over a set of matrices of common dimension. -// -// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -// -template < - typename TensorRefCollectionA, - typename TensorRefCollectionB, - typename TensorRefCollectionC, - typename ScalarType, - typename AccumulatorType, - typename InnerProductOp, - typename ConvertOp -> -void BatchedGemm( - gemm::GemmCoord problem_size, - int batch_count, - ScalarType alpha, - TensorRefCollectionA const& tensor_a, - TensorRefCollectionB const& tensor_b, - ScalarType beta, - TensorRefCollectionC &tensor_c, - AccumulatorType initial_accum) { - - static_assert( - TensorRefCollectionA::kRank == 2 && - TensorRefCollectionB::kRank == 2 && - TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); - - // Blocking structure potentially improves performance of reference implementation - // with a minor increase in complexity. - // - // Note, this reference implementation is NOT expected to approach peak performance. - using OutputTile = MatrixShape<4, 4>; - - dim3 block(16, 8); - dim3 grid( - (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), - (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), - batch_count - ); - - // Launch a GEMM kernel - kernel::BatchedGemm< - TensorRefCollectionA, - TensorRefCollectionB, - TensorRefCollectionC, - ScalarType, - AccumulatorType, - OutputTile, - InnerProductOp, - ConvertOp - ><<< grid, block >>>( - problem_size, - alpha, - tensor_a, - tensor_b, - beta, - tensor_c, - initial_accum - ); -} - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -// -// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -// -template < - typename TensorRefCollectionA, - typename TensorRefCollectionB, - typename TensorRefCollectionC, - typename ScalarType, - typename AccumulatorType -> -void BatchedGemm( - gemm::GemmCoord problem_size, - int batch_count, - ScalarType alpha, - TensorRefCollectionA const& tensor_a, - TensorRefCollectionB const& tensor_b, - ScalarType beta, - TensorRefCollectionC &tensor_c) { - - BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h deleted file mode 100644 index bddf596214da62a7aa3177f758db3710dc1d2516..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h +++ /dev/null @@ -1,350 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued GEMM in device-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -namespace cutlass { -namespace reference { -namespace device { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kMblock = 4, - int kNblock = 4 -> -__global__ void GemmComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; - int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; - int batch_idx = blockIdx.z; - - tensor_a.add_pointer_offset(batch_idx * batch_stride_A); - tensor_b.add_pointer_offset(batch_idx * batch_stride_B); - tensor_c.add_pointer_offset(batch_idx * batch_stride_C); - tensor_d.add_pointer_offset(batch_idx * batch_stride_D); - - for (; batch_idx < batch_count; batch_idx += gridDim.z) { - - // Compute matrix product using blocks - ComputeType accum[kMblock][kNblock]; - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b = tensor_b.at(MatrixCoord(k_block, col)); - - ComputeType a_ik = ComputeType(a); - ComputeType b_kj = ComputeType(b); - - if (transform_a == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } - - if (transform_b == ComplexTransform::kConjugate) { - b_kj = conj(b_kj); - } - - accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); - } - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * ScalarType(tensor_c.at(coord))); - } - } - } - - tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); - tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); - tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); - tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); - - } // for (batch_idx) -} - -} // namespace kernel - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void GemmComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - int const kMblock = 4; - int const kNblock = 4; - - dim3 block(16, 8); - dim3 grid( - (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), - (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), - batch_count % std::numeric_limits::max() - ); - - if (grid.y <= std::numeric_limits::max()) { - kernel::GemmComplex< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ScalarType, - ComputeType, - ElementD, - ConvertOp, - InnerProductOp, - kMblock, - kNblock - ><<< grid, block >>>( - problem_size, - alpha, - tensor_a, - transform_a, - tensor_b, - transform_b, - beta, - tensor_c, - tensor_d, - initial_accum, - batch_count, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_stride_D - ); - } else { - // Using bigger thread tile size - int const kBigMblock = 4; - int const kBigNblock = 16; - - dim3 Bigblock(16, 8); - dim3 Biggrid( - (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), - (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), - batch_count % std::numeric_limits::max() - ); - - kernel::GemmComplex< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ScalarType, - ComputeType, - ElementD, - ConvertOp, - InnerProductOp, - kBigMblock, - kBigNblock - ><<< Biggrid, Bigblock >>>( - problem_size, - alpha, - tensor_a, - transform_a, - tensor_b, - transform_b, - beta, - tensor_c, - tensor_d, - initial_accum, - batch_count, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_stride_D - ); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ElementD = ElementC -> -void GemmComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d) { - - GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h deleted file mode 100644 index 48819cf6eaa565b3ec41dbbf78ae244666fd8a65..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +++ /dev/null @@ -1,311 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued GEMM in device code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/complex.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_ref_planar_complex.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -namespace cutlass { -namespace reference { -namespace device { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static int const kGemmPlanarComplexBlockSize = 4; - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add> -> -__global__ void GemmPlanarComplex( - gemm::GemmCoord problem_size, - complex alpha, - TensorRefPlanarComplex tensor_a, - ComplexTransform transform_a, - TensorRefPlanarComplex tensor_b, - ComplexTransform transform_b, - complex beta, - TensorRefPlanarComplex tensor_c, - TensorRefPlanarComplex tensor_d, - complex initial_accum) { - - int const kMblock = kGemmPlanarComplexBlockSize; - int const kNblock = kGemmPlanarComplexBlockSize; - - using ComplexA = typename TensorRefPlanarComplex::ComplexElement; - using ComplexB = typename TensorRefPlanarComplex::ComplexElement; - using ComplexC = typename TensorRefPlanarComplex::ComplexElement; - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - complex accum[kMblock][kNblock]; - - int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; - int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - accum[i][j] = initial_accum; - } - } - - CUTLASS_PRAGMA_NO_UNROLL - for (int k_block = 0; k_block < K; ++k_block) { - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - - ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); - ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); - - complex a = complex{ - ComputeType(a_ik.real()), - ComputeType(a_ik.imag()) - }; - - complex b = complex{ - ComputeType(b_kj.real()), - ComputeType(b_kj.imag()) - }; - - if (transform_a == ComplexTransform::kConjugate) { - a = conj(a); - } - - if (transform_b == ComplexTransform::kConjugate) { - b = conj(b); - } - - accum[i][j] = inner_product_op(a, b, accum[i][j]); - } - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - - complex acc{ - ScalarType(accum[i][j].real()), - ScalarType(accum[i][j].imag()) - }; - - ComplexC c_ij = ComplexC(); - - if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { - c_ij = tensor_c.at(coord); - } - - complex src{ - ScalarType(c_ij.real()), - ScalarType(c_ij.imag()) - }; - - complex result = alpha * acc + beta * src; - - ComplexC d_ij; - - d_ij.real() = convert_op(result.real()); - d_ij.imag() = convert_op(result.imag()); - - tensor_d.at(coord) = d_ij; - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add> -> -void GemmPlanarComplex( - gemm::GemmCoord problem_size, - complex alpha, - TensorRefPlanarComplex tensor_a, - ComplexTransform transform_a, - TensorRefPlanarComplex tensor_b, - ComplexTransform transform_b, - complex beta, - TensorRefPlanarComplex tensor_c, - TensorRefPlanarComplex tensor_d, - complex initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - int const kMblock = kernel::kGemmPlanarComplexBlockSize; - int const kNblock = kernel::kGemmPlanarComplexBlockSize; - - dim3 block(16, 8); - - dim3 grid( - (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), - (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), - 1); - - kernel::GemmPlanarComplex< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ScalarType, - ComputeType, - ConvertOp, - InnerProductOp - ><<< grid, block >>>( - problem_size, - alpha, - tensor_a, - transform_a, - tensor_b, - transform_b, - beta, - tensor_c, - tensor_d, - initial_accum - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType -> -void GemmPlanarComplex( - gemm::GemmCoord problem_size, - complex alpha, - TensorRefPlanarComplex tensor_a, - ComplexTransform transform_a, - TensorRefPlanarComplex tensor_b, - ComplexTransform transform_b, - complex beta, - TensorRefPlanarComplex tensor_c, - TensorRefPlanarComplex tensor_d) { - - GemmPlanarComplex( - problem_size, - alpha, - tensor_a, transform_a, - tensor_b, transform_b, - beta, - tensor_c, - tensor_d, - complex()); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp deleted file mode 100644 index 497a257d170c411d891942f62fa2c960453d03d5..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp +++ /dev/null @@ -1,146 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief GETT device reference code -*/ -#pragma once - -#include - -namespace cutlass::reference::device { - -template < - class ATensor, - class BTensor, - class CTensor, - class DTensor, - class ElementAccumulator, - class ElementEpilogue> -__global__ static -void -gett_kernel( - DTensor D, - ATensor const A, - BTensor const B, - CTensor const C, - ElementEpilogue alpha, ElementEpilogue beta, - ElementAccumulator acc_init) -{ - using namespace cute; - - static_assert(DTensor::rank == 3, "(M,N,L)"); - static_assert(ATensor::rank == 3, "(M,K,L)"); - static_assert(BTensor::rank == 3, "(N,K,L)"); - static_assert(CTensor::rank == 3, "(M,N,L)"); - - assert(size<0>(A) == size<0>(D)); // M - assert(size<0>(C) == size<0>(D)); // M - assert(size<0>(B) == size<1>(D)); // N - assert(size<1>(C) == size<1>(D)); // N - assert(size<1>(A) == size<1>(B)); // K - assert(size<2>(A) == size<2>(D)); // L - assert(size<2>(B) == size<2>(D)); // L - assert(size<2>(C) == size<2>(D)); // L - - NumericConverter a_converter; - NumericConverter b_converter; - NumericConverter acc_converter; - NumericConverter source_converter; - NumericConverter output_converter; - - // Thread id to each element of D - for (int tid = threadIdx.x + blockDim.x * blockIdx.x; - tid < size(D); - tid += blockDim.x * gridDim.x) { - // (m,n,l) coordinate - auto mnl_coord = idx2crd(tid, product_each(shape(D))); - auto m = get<0>(mnl_coord); - auto n = get<1>(mnl_coord); - auto l = get<2>(mnl_coord); - - auto A_ml = A(m,_,l); - auto B_nl = B(n,_,l); - - ElementAccumulator accum = ElementAccumulator(0); - for (int k = 0; k < size<1>(A); ++k) { - ElementAccumulator a = a_converter(A_ml(k)); - ElementAccumulator b = b_converter(B_nl(k)); - accum += a * b; - } - - ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); - D(m,n,l) = output_converter(scaled_output); - } -} - -// Most general version -template < - class ProblemShapeMNKL, - class ElementA, - class StrideA, - class ElementB, - class StrideB, - class ElementAccumulator, - class ElementC, - class StrideC, - class ElementD, - class StrideD, - class ElementEpilogue> -void -gett( - ProblemShapeMNKL problem_shape_mnkl, - ElementA const* ptr_A, StrideA stride_a_mkl, - ElementB const* ptr_B, StrideB stride_b_nkl, - ElementAccumulator _, - ElementC const* ptr_C, StrideC stride_c_mnl, - ElementD * ptr_D, StrideD stride_d_mnl, - ElementEpilogue alpha, ElementEpilogue beta, - cudaStream_t stream = 0) { - using namespace cute; - - static_assert(cute::rank(ProblemShapeMNKL{}) == 4); - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto K = get<2>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - // Represent the full tensors - auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) - auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L) - auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) - auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) - - dim3 dimBlock(256); - dim3 dimGrid(240); - gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); -} - -} // namespace cutlass::reference::device diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h deleted file mode 100644 index 6e131126a336420a2b0e843e3ead3d89fce637fa..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +++ /dev/null @@ -1,162 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GEMM in host-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/util/reference/device/thread/gemm.h" - -namespace cutlass { -namespace reference { -namespace device { -namespace kernel { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename TensorRefA, - typename TensorRefB, - typename TensorRefC, - typename ScalarType, - typename AccumulatorType, - typename OutputTile, - typename InnerProductOp, - typename ConvertOp -> -__global__ void Gemm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRefA tensor_a, - TensorRefB tensor_b, - ScalarType beta, - TensorRefC tensor_c, - TensorRefC tensor_d, - AccumulatorType initial_accum) { - - // Map each thread to a unique tile of the output matrix - MatrixCoord output_coord( - MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), - MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) - ); - - // Compute the general matrix product - thread::Gemm< - TensorRefA, - TensorRefB, - TensorRefC, - ScalarType, - AccumulatorType, - OutputTile, - InnerProductOp, - ConvertOp - > gemm(initial_accum); - - gemm.multiply_add( - problem_size, - tensor_a, - tensor_b, - output_coord); - - gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename TensorRefCollectionA, - typename TensorRefCollectionB, - typename TensorRefCollectionC, - typename ScalarType, - typename AccumulatorType, - typename OutputTile, - typename InnerProductOp, - typename ConvertOp -> -__global__ void BatchedGemm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRefCollectionA tensor_collection_a, - TensorRefCollectionB tensor_collection_b, - ScalarType beta, - TensorRefCollectionC tensor_collection_c, - AccumulatorType initial_accum) { - - // Obtain batch ID - int batch_id = blockIdx.z; - - // Dereference based on batch_id - typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); - typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); - typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); - - // Map each thread to a unique tile of the output matrix - MatrixCoord output_coord( - (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, - (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow - ); - - // Compute the general matrix product - thread::Gemm< - typename TensorRefCollectionA::TensorRef, - typename TensorRefCollectionB::TensorRef, - typename TensorRefCollectionC::TensorRef, - ScalarType, - AccumulatorType, - OutputTile, - InnerProductOp, - ConvertOp - > gemm(initial_accum); - - gemm.multiply_add( - problem_size, - tensor_a, - tensor_b, - output_coord); - - gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h deleted file mode 100644 index 149e4b2e00e2ac8130cee9dc189a539ba3a70297..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h +++ /dev/null @@ -1,168 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include - -#include "cutlass/cutlass.h" - -namespace cutlass { -namespace reference { -namespace device { -namespace kernel { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel to initialize tensor to uniform random distribution -template -__global__ void TensorInitializeUniform( - Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { - __shared__ curandState_t rng_state[1024]; - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; - - curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); - - int c_idx = blockIdx.x * blockDim.x + threadIdx.x; - int s_idx = blockIdx.y * blockDim.x; - - tensor += s_idx * ldm + c_idx; - - for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { - if (s_idx < dim_strided && c_idx < dim_contiguous) { - double range = dist.uniform.max - dist.uniform.min; - - double rnd = curand_uniform(&rng_state[threadIdx.x]); - - rnd = dist.uniform.min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - if (dist.int_scale >= 0) { - rnd = double(int(rnd * double(1 << dist.int_scale))); - *tensor = T(rnd / double(1 << dist.int_scale)); - } else { - *tensor = T(rnd); - } - - tensor += ldm; - } - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel to initialize tensor to uniform distribution -template -__global__ void TensorInitializeGaussian( - Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { - __shared__ curandState_t rng_state[1024]; - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; - - curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); - - int c_idx = blockIdx.x * blockDim.x + threadIdx.x; - int s_idx = blockIdx.y * blockDim.x; - - tensor += s_idx * ldm + c_idx; - - for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { - if (s_idx < dim_strided && c_idx < dim_contiguous) { - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - - double rnd = curand_normal(&rng_state[threadIdx.x]); - - rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; - - if (dist.int_scale >= 0) { - rnd = double(int(rnd * double(1 << dist.int_scale))); - *tensor = T(rnd / double(1 << dist.int_scale)); - } else { - *tensor = T(rnd); - } - } - } -} - -/// Kernel to initialize tensor to an identity matrix -template -__global__ void TensorInitializeLinear( - Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { - __shared__ curandState_t rng_state[1024]; - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; - - curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); - - int c_idx = blockIdx.x * blockDim.x + threadIdx.x; - int s_idx = blockIdx.y * blockDim.x; - - tensor += s_idx * ldm + c_idx; - - for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { - if (s_idx < dim_strided && c_idx < dim_contiguous) { - *tensor = - dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; - } - } -} - -/// Kernel to initialize tensor to an identity matrix -template -__global__ void TensorInitializeIdentity( - Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { - __shared__ curandState_t rng_state[1024]; - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; - - curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); - - int c_idx = blockIdx.x * blockDim.x + threadIdx.x; - int s_idx = blockIdx.y * blockDim.x; - - tensor += s_idx * ldm + c_idx; - - for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { - if (s_idx < dim_strided && c_idx < dim_contiguous) { - *tensor = (c_idx == s_idx ? T(1) : T(0)); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h deleted file mode 100644 index 3223cb2056ba6d88f47f7b117392a56e325d0ce7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +++ /dev/null @@ -1,159 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/coord.h" -#include "cutlass/subbyte_reference.h" -#include "cutlass/fast_math.h" - -namespace cutlass { -namespace reference { -namespace device { -namespace kernel { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines several helpers -namespace detail { - -/// Helper to perform for-each operation -template -struct TensorForEachHelper { - - /// Constructor for general rank - __inline__ __device__ - TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { - - int64_t product = 1; - - CUTLASS_PRAGMA_UNROLL - for (int i = Rank - RankRemaining; i < Rank; ++i) { - product *= size[i]; - } - - coord[Rank - 1 - RankRemaining] = index / product; - int64_t remaining = index % product; - - TensorForEachHelper(func, size, coord, remaining); - } -}; - -/// Helper to perform for-each operation -template -struct TensorForEachHelper { - - /// Constructor for fastest changing rank - __inline__ __device__ - TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { - - coord[Rank - 1] = index; - - if (coord < size) { - func(coord); - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel calls a functor for each element in a tensor's index space -template -__global__ void TensorForEach(Coord size, Params params = Params()) { - - Func func(params); - - int64_t index = threadIdx.x + blockIdx.x * blockDim.x; - int64_t max_index = 1; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Rank; ++i) { - max_index *= size[i]; - } - - CUTLASS_PRAGMA_NO_UNROLL - while (index < max_index) { - Coord coord; - - detail::TensorForEachHelper(func, size, coord, index); - index += blockDim.x * gridDim.x; - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Kernel calls a functor for each element along a tensor's diagonal -template -__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { - - Func func(params); - - int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; - - if (index < end) { - Coord coord; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Rank; ++i) { - coord[i] = index; - } - - func(coord); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void BlockForEach( - Element *ptr, - size_t capacity, - typename Func::Params params) { - - Func func(params); - - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - - for (; index < capacity; index += blockDim.x * gridDim.x) { - ReferenceFactory::get(ptr, index) = func(); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace device -} // namespace reference -} // namespace cutlass - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h deleted file mode 100644 index 2e76fe52b06f9bb1a033c736f94fa01961ce664d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h +++ /dev/null @@ -1,355 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued GEMM in device-side code. -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -namespace cutlass { -namespace reference { -namespace device { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add, - int kMblock = 4, - int kNblock = 4 -> -__global__ void Rank2KComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - FillMode fill_mode_c, - BlasMode blas_mode, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - assert(M=N); - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; - int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; - int batch_idx = blockIdx.z; - - tensor_a.add_pointer_offset(batch_idx * batch_stride_A); - tensor_b.add_pointer_offset(batch_idx * batch_stride_B); - tensor_c.add_pointer_offset(batch_idx * batch_stride_C); - tensor_d.add_pointer_offset(batch_idx * batch_stride_D); - - for (; batch_idx < batch_count; batch_idx += gridDim.z) { - - // Compute matrix product using blocks - ComputeType accum[kMblock][kNblock]; - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N && - ( (fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col) ) - ) { - - // A x B^T (Symmetric) or A x B^H (Hermitian) - // complex conjugation on operandB (b_t) is function of blas3 computation - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b_t = (blas_mode == BlasMode::kHermitian) ? - conj(tensor_b.at(MatrixCoord(col, k_block))) : - tensor_b.at(MatrixCoord(col, k_block)); - - ComputeType a_ik = ComputeType(a); - ComputeType b_jk = ComputeType(b_t); - - // complex conjugation is a function of operand layouts - if (transform_a == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } - // complex conjugation is a function of operand layouts - if (transform_b == ComplexTransform::kConjugate) { - b_jk = conj(b_jk); - } - - accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); - - // B x A^T (Symmetric) or B x A^H (Hermitian) - // complex conjugation on operandB (a_t) is function of blas3 computation - ElementB b = tensor_b.at(MatrixCoord(row, k_block)); - ElementA a_t = (blas_mode == BlasMode::kHermitian) ? - conj(tensor_a.at(MatrixCoord(col, k_block))): - tensor_a.at(MatrixCoord(col, k_block)); - - ComputeType b_ik = ComputeType(b); - ComputeType a_jk = ComputeType(a_t); - - // complex conjugation here is a function of operand layouts - if (transform_b == ComplexTransform::kConjugate) { - b_ik = conj(b_ik); - } - // complex conjugation here is a function of operand layouts - if (transform_a == ComplexTransform::kConjugate) { - a_jk = conj(a_jk); - } - - accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); - } - } - } - } - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kNblock; j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kMblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N && - ((fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col)) - ) { - - ScalarType c = tensor_c.at(coord); - // The imaginary parts of the diagonal elements of - // a complex data type are assumed and set to zero - if (blas_mode == BlasMode::kHermitian) { - c = (row == col) ? real(c) : c; - } - - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * c); - } - } - } - - tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); - tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); - tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); - tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); - - } // for (batch_idx) -} - -} // namespace kernel - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Rank2KComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - FillMode fill_mode_c, - BlasMode blas_mode, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - int const kMblock = 4; - int const kNblock = 4; - - dim3 block(16, 8); - dim3 grid( - (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), - (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), - batch_count % std::numeric_limits::max() - ); - - kernel::Rank2KComplex< - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - ScalarType, - ComputeType, - ConvertOp, - InnerProductOp, - kMblock, - kNblock - ><<< grid, block >>>( - problem_size, - alpha, - tensor_a, - transform_a, - tensor_b, - transform_b, - beta, - tensor_c, - tensor_d, - initial_accum, - fill_mode_c, - blas_mode, - batch_count, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_stride_D - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType -> -void Rank2KComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - FillMode fill_mode_c, - BlasMode blas_mode) { - - Rank2KComplex( - problem_size, alpha, - tensor_a, transform_a, - tensor_b, transform_b, - beta, tensor_c, tensor_d, - ScalarType(0), - fill_mode_c, - blas_mode); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h deleted file mode 100644 index 1999730f6d24e69aef152aa332fae68af57a9c40..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ /dev/null @@ -1,250 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines host-side elementwise operations on TensorView. -*/ - -#pragma once -// Standard Library includes -#include - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/relatively_equal.h" - -#include "cutlass/util/distribution.h" - -#include "tensor_foreach.h" - -namespace cutlass { -namespace reference { -namespace device { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -template -__global__ void BlockCompareEqual( - int *equal, - Element const *ptr_A, - Element const *ptr_B, - size_t capacity) { - - size_t idx = threadIdx.x + blockDim.x * blockIdx.x; - - for (; idx < capacity; idx += gridDim.x * blockDim.x) { - - Element a = cutlass::ReferenceFactory::get(ptr_A, idx); - Element b = cutlass::ReferenceFactory::get(ptr_B, idx); - - if (a != b) { - *equal = 0; - - return; - } - } -} - -template -__global__ void BlockCompareRelativelyEqual( - int *equal, - Element const *ptr_A, - Element const *ptr_B, - size_t capacity, - Element epsilon, - Element nonzero_floor) { - - size_t idx = threadIdx.x + blockDim.x * blockIdx.x; - - for (; idx < capacity; idx += gridDim.x * blockDim.x) { - - Element a = cutlass::ReferenceFactory::get(ptr_A, idx); - Element b = cutlass::ReferenceFactory::get(ptr_B, idx); - - if (!relatively_equal(a, b, epsilon, nonzero_floor)) { - *equal = 0; - return; - } - } -} - -} // namespace kernel - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Performs a bit-level equality check between two blocks -template -bool BlockCompareEqual( - Element const *ptr_A, - Element const *ptr_B, - size_t capacity, - int grid_size = 0, - int block_size = 0, - cudaStream_t stream = nullptr) { - - int equal_flag = 1; - int *device_equal_flag = nullptr; - - if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { - throw std::runtime_error("Failed to allocate device flag."); - } - - if (cudaMemcpy( - device_equal_flag, - &equal_flag, - sizeof(int), - cudaMemcpyHostToDevice) != cudaSuccess) { - - throw std::runtime_error("Failed to copy equality flag to device."); - } - - if (!grid_size || !block_size) { - - // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API - cudaError_t result = cudaOccupancyMaxPotentialBlockSize( - &grid_size, - &block_size, - reinterpret_cast(kernel::BlockCompareEqual)); - - if (result != cudaSuccess) { - throw std::runtime_error("Failed to query occupancy."); - } - // Limit block size. This has the effect of increasing the number of items processed by a - // single thread and reduces the impact of initialization overhead. - block_size = (block_size < 128 ? block_size : 128); - } - - dim3 grid(grid_size, 1, 1); - dim3 block(block_size, 1, 1); - - kernel::BlockCompareEqual<<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity); - - cudaStreamSynchronize(stream); - - if (cudaMemcpy( - &equal_flag, - device_equal_flag, - sizeof(int), - cudaMemcpyDeviceToHost) != cudaSuccess) { - - cudaFree(device_equal_flag); - - throw std::runtime_error("Failed to copy equality flag from device."); - } - - cudaFree(device_equal_flag); - - return equal_flag; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Performs a bit-level equality check between two blocks -template -bool BlockCompareRelativelyEqual( - Element const *ptr_A, - Element const *ptr_B, - size_t capacity, - Element epsilon, - Element nonzero_floor, - int grid_size = 0, - int block_size = 0, - cudaStream_t stream = nullptr) { - - int equal_flag = 1; - int *device_equal_flag = nullptr; - - if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { - throw std::runtime_error("Failed to allocate device flag."); - } - - if (cudaMemcpy( - device_equal_flag, - &equal_flag, - sizeof(int), - cudaMemcpyHostToDevice) != cudaSuccess) { - - throw std::runtime_error("Failed to copy equality flag to device."); - } - - if (!grid_size || !block_size) { - - // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API - cudaError_t result = cudaOccupancyMaxPotentialBlockSize( - &grid_size, - &block_size, - reinterpret_cast(kernel::BlockCompareRelativelyEqual)); - - if (result != cudaSuccess) { - throw std::runtime_error("Failed to query occupancy."); - } - // Limit block size. This has the effect of increasing the number of items processed by a - // single thread and reduces the impact of initialization overhead. - block_size = (block_size < 128 ? block_size : 128); - } - - dim3 grid(grid_size, 1, 1); - dim3 block(block_size, 1, 1); - - kernel::BlockCompareRelativelyEqual<<< grid, block, 0, stream >>>( - device_equal_flag, - ptr_A, - ptr_B, - capacity, - epsilon, - nonzero_floor - ); - - cudaStreamSynchronize(stream); - - if (cudaMemcpy( - &equal_flag, - device_equal_flag, - sizeof(int), - cudaMemcpyDeviceToHost) != cudaSuccess) { - - cudaFree(device_equal_flag); - - throw std::runtime_error("Failed to copy equality flag from device."); - } - - cudaFree(device_equal_flag); - - return equal_flag; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // device -} // reference -} // cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h deleted file mode 100644 index a19b42825f6efb4a39466fe1cfc182ab7d831079..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ /dev/null @@ -1,2075 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines device-side elementwise operations on TensorView. Note, the operations defined - in this header are not specialized for any particular data layout and are therefore not - intended to offer the best possible performance. Rather, they are intended to be generic - reference implementations to support the CUTLASS unit tests. -*/ - -#pragma once - -#if !defined(__CUDACC_RTC__) - -// Standard Library includes -#include -#include -#include -#include -#include - -#endif - -// CUDA includes -#include - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/complex.h" -#include "cutlass/tensor_view.h" -#include "cutlass/blas3.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/layout/vector.h" - -#include "cutlass/util/reference/device/tensor_foreach.h" -#include "cutlass/util/distribution.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace device { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -CUTLASS_DEVICE -FloatType random_normal_float(curandState_t *state) { - return curand_normal(state); -} - -template <> -CUTLASS_DEVICE -double random_normal_float(curandState_t *state) { - return curand_normal_double(state); -} - -template -CUTLASS_DEVICE -FloatType random_uniform_float(curandState_t *state) { - return curand_uniform(state); -} - -template <> -CUTLASS_DEVICE -double random_uniform_float(curandState_t *state) { - return curand_uniform_double(state); -} - -template -struct RandomGaussianFunc { - - using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; - using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; - - /// Parameters structure - struct Params { - - // - // Data members - // - - uint64_t seed; - FloatType mean; - FloatType stddev; - int int_scale; - FloatType float_scale_up; - FloatType float_scale_down; - int exclude_zero; ///< If non-negative, excludes zeros - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - uint64_t seed_ = 0, - Element mean_ = 0, - Element stddev_ = 1, - int int_scale_ = -1, - int exclude_zero_ = -1 - ): - seed(seed_), - mean(static_cast(mean_)), - stddev(static_cast(stddev_)), - int_scale(int_scale_), - exclude_zero(exclude_zero_) { - - float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits - float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - /// RNG state object - curandState_t rng_state; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - RandomGaussianFunc(Params const ¶ms): params(params) { - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; - - curand_init(params.seed, gtid, 0, &rng_state); - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - Element operator()() { - - FloatType rnd = random_normal_float(&rng_state); - rnd = params.mean + params.stddev * rnd; - - Element result; - if (params.int_scale >= 0) { - rnd = FloatType(std::llround(rnd * params.float_scale_up)); - result = Element(rnd * params.float_scale_down); - } - else { - result = Element(rnd); - } - - if (params.exclude_zero >=0 && result == Element(0.0)) { - if (rnd > FloatType(0)) { - rnd += FloatType(1); - } else { - rnd -= FloatType(1); - } - result = Element(rnd); - } - - return result; - } -}; - - -template -struct RandomGaussianFunc> { - - using Element = complex; - using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; - using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; - - /// Parameters structure - struct Params { - - // - // Data members - // - - uint64_t seed; - FloatType mean; - FloatType stddev; - int int_scale; - FloatType float_scale_up; - FloatType float_scale_down; - int exclude_zero; ///< If non-negative, excludes zeros - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - uint64_t seed_ = 0, - Real mean_ = 0, - Real stddev_ = 1, - int int_scale_ = -1, - int exclude_zero_ = -1 - ): - seed(seed_), - mean(static_cast(mean_)), - stddev(static_cast(stddev_)), - int_scale(int_scale_), - exclude_zero(exclude_zero_) { - - float_scale_up = FloatType(IntType(1) << int_scale); - float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - /// RNG state object - curandState_t rng_state; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - RandomGaussianFunc(Params const ¶ms): params(params) { - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; - - curand_init(params.seed, gtid, 0, &rng_state); - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - Element operator()() { - - FloatType rnd_r = random_normal_float(&rng_state); - FloatType rnd_i = random_normal_float(&rng_state); - rnd_r = params.mean + params.stddev * rnd_r; - rnd_i = params.mean + params.stddev * rnd_i; - - Element result; - if (params.int_scale >= 0) { - rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); - rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); - - result = { - Real(rnd_r * params.float_scale_down), - Real(rnd_i * params.float_scale_down) - }; - } - else { - result = Element(Real(rnd_r), Real(rnd_i)); - } - - if (params.exclude_zero >= 0 && - result.real() == Real(0.0) && - result.imag() == Real(0.0)) { - - if (rnd_r > FloatType(0)) { - rnd_r += FloatType(1); - } else { - rnd_r -= FloatType(1); - } - result = Element(Real(rnd_r), Real(rnd_i)); - } - - return result; - } -}; - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillRandomGaussianFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - using RandomFunc = RandomGaussianFunc; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - typename RandomFunc::Params random; - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_ = TensorView(), - typename RandomFunc::Params random_ = typename RandomFunc::Params() - ): - view(view_), random(random_) { - - } - }; - - // - // Data members - // - - Params params; - RandomFunc random; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { - - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - params.view.at(coord) = random(); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a Gaussian distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomGaussian( - TensorView view, ///< destination tensor - uint64_t seed, ///< seed for RNG - typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean - typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init - cudaStream_t stream = nullptr) { - - using RandomFunc = detail::RandomGaussianFunc; - using Func = detail::TensorFillRandomGaussianFunc; - using Params = typename Func::Params; - - TensorForEach( - view.extent(), - Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a Gaussian distribution. -template ///< Element type -void BlockFillRandomGaussian( - Element *ptr, - size_t capacity, - uint64_t seed, ///< seed for RNG - typename RealType::Type mean, ///< Gaussian distribution's mean - typename RealType::Type stddev, ///< Gaussian distribution's standard deviation - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - cudaStream_t stream = nullptr) { - - using RandomFunc = detail::RandomGaussianFunc; - - typename RandomFunc::Params params(seed, mean, stddev, bits); - - BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Computes a random uniform distribution -template ///< Element type -struct RandomUniformFunc { - - using FloatType = typename std::conditional< - (sizeof(Element) > 4), - double, - float>::type; - - using IntType = typename std::conditional< - (sizeof(Element) > 4), - int64_t, - int>::type; - - /// Parameters structure - struct Params { - - // - // Data members - // - - uint64_t seed; - FloatType range; - FloatType max; - int int_scale; - double pnan; - FloatType float_scale_up; - FloatType float_scale_down; - int exclude_zero; ///< If non-negative, excludes zeros - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - uint64_t seed_ = 0, - Element max_ = 1, - Element min = 0, - int int_scale_ = -1, - double pnan_ = 0, - int exclude_zero_ = -1 - ): - seed(seed_), - range(static_cast(max_) - static_cast(min)), - max(static_cast(max_)), - int_scale(int_scale_), - pnan(pnan_), - exclude_zero(exclude_zero_) { - - float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits - float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); - - // Handle cases where min = 0 or max = 0 for excluding zeros - if (exclude_zero >= 0) { - range = (min == Element(0)) ? range - FloatType(1): range; - max = (max_ == Element(0)) ? max - FloatType(1): max; - } - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - /// RNG state object - curandState_t rng_state; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - RandomUniformFunc(Params const ¶ms): params(params) { - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; - - curand_init(params.seed, gtid, 0, &rng_state); - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - Element operator()() { - - // Draw random float in [0.0, 1.0] to determine if element should be NaN. - if constexpr (std::numeric_limits::has_quiet_NaN) { - if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { - return Element(NAN); - } - } - - FloatType rnd = random_uniform_float(&rng_state); - rnd = params.max - params.range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - Element result; - - if (params.int_scale >= 0) { - rnd = FloatType(std::llround(rnd * params.float_scale_up)); - result = Element(rnd * params.float_scale_down); - } - else { - result = Element(rnd); - } - - if (params.exclude_zero >=0 && result == Element(0.0)) { - if (rnd > FloatType(0)) { - rnd = std::min(params.max, rnd + FloatType(1)); - } else { - rnd = std::max((params.max - params.range), rnd - FloatType(1)); - } - result = Element(rnd); - } - - return result; - } -}; - -/// Computes a random Gaussian distribution -template -struct RandomUniformFunc> { - - using Element = complex; - - using FloatType = typename std::conditional< - (sizeof(Real) > 4), - double, - float>::type; - - using IntType = typename std::conditional< - (sizeof(Real) > 4), - int64_t, - int>::type; - - /// Parameters structure - struct Params { - - // - // Data members - // - - uint64_t seed; - FloatType range; - FloatType min; - int int_scale; - double pnan; - FloatType float_scale_up; - FloatType float_scale_down; - int exclude_zero; ///< If non-negative, excludes zeros - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - uint64_t seed_ = 0, - FloatType max = 1, - FloatType min_ = 0, - int int_scale_ = -1, - double pnan_ = 0, - int exclude_zero_ = -1 - ): - seed(seed_), - range(static_cast(max - min_)), - min(static_cast(min_)), - int_scale(int_scale_), - pnan(pnan_), - exclude_zero(exclude_zero_) { - - float_scale_up = FloatType(IntType(1) << int_scale); - float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); - - // Handle cases where min = 0 or max = 0 for excluding zeros - if (exclude_zero >= 0) { - min = (min == FloatType(0)) ? min + FloatType(1): min; - range = (max == FloatType(0)) ? range - FloatType(1): range; - } - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - /// RNG state object - curandState_t rng_state; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - RandomUniformFunc(Params const ¶ms): params(params) { - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; - - curand_init(params.seed, gtid, 0, &rng_state); - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - Element operator()() { - - // Draw random float in [0.0, 1.0] to determine if element should be NaN. - if constexpr (std::numeric_limits::has_quiet_NaN) { - if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { - return Element(Real(NAN), Real(NAN)); - } - } - - FloatType rnd_r = random_uniform_float(&rng_state); - FloatType rnd_i = random_uniform_float(&rng_state); - - rnd_r = params.min + params.range * rnd_r; - rnd_i = params.min + params.range * rnd_i; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - Element result; - - if (params.int_scale >= 0) { - rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); - rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); - - result = { - Real(rnd_r * params.float_scale_down), - Real(rnd_i * params.float_scale_down) - }; - } - else { - result = Element(Real(rnd_r), Real(rnd_i)); - } - - if (params.exclude_zero >= 0 && - result.real() == Real(0.0) && - result.imag() == Real(0.0)) { - - if (rnd_r > FloatType(0)) { - rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1)); - } else { - rnd_r = std::max((params.min), rnd_r - FloatType(1)); - } - result = Element(Real(rnd_r), Real(rnd_i)); - } - - return result; - } -}; - -/// Computes a random uniform distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillRandomUniformFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - using RandomFunc = RandomUniformFunc; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - typename RandomFunc::Params random; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_ = TensorView(), - typename RandomFunc::Params random_ = RandomFunc::Params() - ): - view(view_), random(random_) { - - } - }; - - // - // Data members - // - - Params params; - RandomFunc random; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - params.view.at(coord) = random(); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomUniform( - TensorView view, ///< destination tensor - uint64_t seed, ///< seed for RNG - typename RealType::Type max = Element(1), ///< upper bound of distribution - typename RealType::Type min = Element(0), ///< lower bound for distribution - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - double pnan = 0, ///< Percentage of NaN elements. - int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init - cudaStream_t stream = nullptr) { - - using RandomFunc = detail::RandomUniformFunc; - using Func = detail::TensorFillRandomUniformFunc; - using Params = typename Func::Params; - - typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero); - - TensorForEach( - view.extent(), - Params(view, random), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template -void BlockFillRandomUniform( - Element *ptr, - size_t capacity, - uint64_t seed, ///< seed for RNG - typename RealType::Type max, ///< upper bound of distribution - typename RealType::Type min, ///< lower bound for distribution - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - double pnan = 0, ///< Percentage of NaN elements. - cudaStream_t stream = nullptr) { - - using RandomFunc = detail::RandomUniformFunc; - - typename RandomFunc::Params params(seed, max, min, bits, pnan); - - BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Computes a random sparse meta -template ///< Element type -struct RandomSparseMetaFunc { - - using FloatType = float; - - using IntType = int32_t; - - /// Parameters structure - struct Params { - - // - // Data members - // - - uint64_t seed; - FloatType range; - int MetaSizeInBits; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - uint64_t seed_ = 0, - int MetaSizeInBits_ = 2 - ): - seed(seed_), - MetaSizeInBits(MetaSizeInBits_) { - if (MetaSizeInBits_ == 2) { - range = 6; - } - else if (MetaSizeInBits_ == 4) { - range = 2; - } - else { - throw std::invalid_argument("Invalid MetaSizeInBits"); - } - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - /// RNG state object - curandState_t rng_state; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - RandomSparseMetaFunc(Params const ¶ms): params(params) { - - uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; - - curand_init(params.seed, gtid, 0, &rng_state); - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - Element operator()() { - Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; - Element TwoToOneMeta[2] = {0x4, 0xe}; - - Element *MetaArray = - (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; - - Element result = 0x0; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { - FloatType rnd = random_uniform_float(&rng_state); - rnd = params.range * rnd; - Element meta = MetaArray[(int)rnd]; - - result = (Element)(result | ((Element)(meta << (i * 4)))); - } - - return result; - } -}; - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillRandomSparseMetaFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - using RandomFunc = RandomSparseMetaFunc; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - typename RandomFunc::Params random; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_ = TensorView(), - typename RandomFunc::Params random_ = RandomFunc::Params() - ): - view(view_), random(random_) { - - } - }; - - // - // Data members - // - - Params params; - RandomFunc random; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - params.view.at(coord) = random(); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomSparseMeta( - TensorView view, ///< destination tensor - uint64_t seed, ///< seed for RNG - int MetaSizeInBits = 2, ///< meta data size - cudaStream_t stream = nullptr) { - - using RandomFunc = detail::RandomSparseMetaFunc; - using Func = detail::TensorFillRandomUniformFunc; - using Params = typename Func::Params; - - typename RandomFunc::Params random(seed, MetaSizeInBits); - - TensorForEach( - view.extent(), - Params(view, random), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template -void BlockFillRandomSparseMeta( - Element *ptr, - size_t capacity, - uint64_t seed, ///< seed for RNG - int MetaSizeInBits = 2, ///< meta data size - cudaStream_t stream = nullptr) { - - using RandomFunc = detail::RandomSparseMetaFunc; - - typename RandomFunc::Params params(seed, MetaSizeInBits); - - BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillDiagonalFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element diag; - Element other; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - Params( - TensorView view_ = TensorView(), - Element diag_ = Element(1), - Element other_ = Element(0) - ): - view(view_), diag(diag_), other(other_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorFillDiagonalFunc(Params const ¶ms): params(params) { - - } - - /// Updates the tensor - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - bool is_diag = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[i - 1]) { - is_diag = false; - break; - } - } - - params.view.at(coord) = (is_diag ? params.diag : params.other); - } -}; - -// Overwrites the elements of a tensor with a uniform value depending on fill mode -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillPartialFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element element; - FillMode fill_mode; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params(): fill_mode(FillMode::kNone) { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_, - Element element_, - FillMode fill_mode_ - ): - view(view_), element(element_), fill_mode(fill_mode_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - CUTLASS_DEVICE - TensorFillPartialFunc(Params const ¶ms): params(params) { - - } - - /// Overwrites the element if it is within the covered region. - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - bool predicate = true; - - switch (params.fill_mode) { - case FillMode::kFull: - predicate = true; - break; - - case FillMode::kLower: - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i - 1] < coord[i]) { - predicate = false; - break; - } - } - break; - - case FillMode::kUpper: - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i - 1] > coord[i]) { - predicate = false; - break; - } - } - break; - - case FillMode::kDiagonal: - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i - 1] != coord[i]) { - predicate = false; - break; - } - } - break; - - case FillMode::kNone: // fall-through - - default: - predicate = false; - break; - } - - if (predicate) { - params.view.at(coord) = params.element; - } - } -}; - - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorClearPartialFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// - static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); - - /// Parameters structure - struct Params { - TensorView view{}; - Element element{}; - FillMode fill_mode{FillMode::kNone}; - int alignment{0}; - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - CUTLASS_DEVICE - TensorClearPartialFunc(Params const ¶ms): params(params) { - - } - - /// Overwrites the element if it is within the covered region. - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - bool predicate = true; - - switch (params.fill_mode) { - - case FillMode::kLower: - if ((coord[0] >= coord[1]) || - ((coord[1] - coord[0]) >= params.alignment)) { - predicate = false; - break; - } - break; - - case FillMode::kUpper: - if ((coord[0] <= coord[1]) || - ((coord[0] - coord[1]) >= params.alignment)) { - predicate = false; - break; - } - break; - - case FillMode::kNone: // fall-through - - default: - predicate = false; - break; - } - - if (predicate) { - params.view.at(coord) = params.element; - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor everywhere with a unique value for its diagonal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillDiagonal( - TensorView view, ///< destination tensor - Element diag = Element(1), ///< value to write in the diagonal - Element other = Element(0), ///< value to write off the diagonal - cudaStream_t stream = nullptr) { - - typedef detail::TensorFillDiagonalFunc Func; - typedef typename Func::Params Params; - - TensorForEach( - view.extent(), - Params(view, diag, other), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are -/// not written. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillPartial( - TensorView view, ///< destination tensor - Element element, - FillMode fill_mode, - cudaStream_t stream = nullptr) { - - typedef detail::TensorFillPartialFunc Func; - typedef typename Func::Params Params; - - TensorForEach( - view.extent(), - Params(view, element, fill_mode), - stream - ); -} - -/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side -/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorClearPartial( - TensorView view, ///< destination tensor - Element element, - FillMode fill_mode, - int alignment, - cudaStream_t stream = nullptr) { - - typedef detail::TensorClearPartialFunc Func; - typedef typename Func::Params Params; - - TensorForEach( - view.extent(), - Params{view, element, fill_mode, alignment}, - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with a uniform value -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFill( - TensorView view, ///< destination tensor - Element val = Element(0), ///< value to uniformly fill it with - cudaStream_t stream = nullptr) { - - TensorFillDiagonal(view, val, val, stream); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor's diagonal with 1 and 0 everywhere else. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillIdentity( - TensorView view, ///< destination tensor - cudaStream_t stream = nullptr) { - - TensorFillDiagonal(view, Element(1), Element(0), stream); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorUpdateDiagonalFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element diag; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_ = TensorView(), - Element diag_ = Element(1) - ): - view(view_), diag(diag_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { - - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - bool is_diag = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[i - 1]) { - is_diag = false; - break; - } - } - - if (is_diag) { - params.view.at(coord) = params.diag; - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorUpdateDiagonal( - TensorView view, ///< destination tensor - Element diag = Element(1), - cudaStream_t stream = nullptr) { - - typedef detail::TensorUpdateDiagonalFunc Func; - typedef typename Func::Params Params; - - TensorForEach( - view.extent(), - Params(view, diag), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorUpdateOffDiagonalFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element other; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_ = TensorView(), - Element other_ = Element(0) - ): - view(view_), other(other_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { - - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - bool is_diag = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[i - 1]) { - is_diag = false; - break; - } - } - - if (!is_diag) { - params.view.at(coord) = params.other; - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorUpdateOffDiagonal( - TensorView view, ///< destination tensor - Element other = Element(1), - cudaStream_t stream = nullptr) { - - typedef detail::TensorUpdateOffDiagonalFunc Func; - typedef typename Func::Params Params; - - TensorForEach( - view.extent(), - Params(view, other), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillLinearFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Array v; - Element s; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_, ///< destination tensor - Array const & v_, - Element s_ = Element(0) - ): - view(view_), v(v_), s(s_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorFillLinearFunc(Params const ¶ms): params(params) { - - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - Element sum = params.s; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Layout::kRank; ++i) { - if constexpr (is_complex::value) { - if constexpr (sizeof_bits::value <= 32) { - sum = Element(static_cast>(sum) + - static_cast>(params.v[i]) * static_cast>(coord[i])); - } - } - else if constexpr (sizeof_bits::value <= 32) { - if constexpr (std::numeric_limits::is_integer) { - sum = Element(static_cast(sum) + - static_cast(params.v[i]) * static_cast(coord[i])); - } - else { - sum = Element(static_cast(sum) + - static_cast(params.v[i]) * static_cast(coord[i])); - } - } - else { - sum += params.v[i] * coord[i]; - } - } - - params.view.at(coord) = sum; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills tensor with a linear combination of its coordinate and another vector -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillLinear( - TensorView view, ///< destination tensor - Array const & v, - Element s = Element(0), - cudaStream_t stream = nullptr) { - - using Func = detail::TensorFillLinearFunc; - using Params = typename Func::Params; - - TensorForEach( - view.extent(), - Params(view, v, s), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values from a distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandom( - TensorView view, ///< destination tensor - uint64_t seed, - Distribution dist, - cudaStream_t stream = nullptr, - int exclude_zero = -1 ///< If non-negative, excludes 0. - /// Note that setting this flag will result in more 1's, - /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. - ) { - - using Real = typename RealType::Type; - - if (dist.kind == Distribution::Gaussian) { - TensorFillRandomGaussian( - view, - seed, - static_cast(dist.gaussian.mean), - static_cast(dist.gaussian.stddev), - dist.int_scale, - exclude_zero, - stream); - } else if (dist.kind == Distribution::Uniform) { - TensorFillRandomUniform( - view, - seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), - dist.int_scale, - dist.uniform.pnan, - exclude_zero, - stream); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillSequential( - Element *ptr, - int64_t capacity, - Element v = Element(1), - Element s = Element(0)) { - - using Layout = layout::PackedVectorLayout; - Layout::TensorCoord size(static_cast(capacity)); // -Wconversion - Layout layout = Layout::packed(size); - TensorView view(ptr, layout, size); - - Array c{}; - c[0] = v; - - TensorFillLinear(view, c, s); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillRandom( - Element *ptr, - size_t capacity, - uint64_t seed, - Distribution dist, - cudaStream_t stream = nullptr) { - - using Real = typename RealType::Type; - - if (dist.kind == Distribution::Gaussian) { - BlockFillRandomGaussian( - ptr, - capacity, - seed, - static_cast(dist.gaussian.mean), - static_cast(dist.gaussian.stddev), - dist.int_scale, - stream); - } - else if (dist.kind == Distribution::Uniform) { - BlockFillRandomUniform( - ptr, - capacity, - seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), - dist.int_scale, - dist.uniform.pnan, - stream); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorCopyDiagonalInFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element const *ptr; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_, ///< destination tensor - Element const *ptr_ - ): - view(view_), ptr(ptr_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { - - } - - /// Only update the diagonal element - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - bool is_diagonal = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[0]) { - is_diagonal = false; - } - } - if (is_diagonal) { - params.view.at(coord) = params.ptr[coord[0]]; - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies a diagonal in from host memory without modifying off-diagonal elements. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorCopyDiagonalIn( - TensorView view, ///< destination tensor - Element const *ptr, ///< dense buffer of elements - cudaStream_t stream = nullptr) { - - using Func = detail::TensorCopyDiagonalInFunc; - using Params = typename Func::Params; - - TensorForEach( - view.extent(), - Params(view, ptr), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - - -namespace detail { - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorCopyDiagonalOutFunc { - - /// View type - using TensorView = TensorView; - - /// Scalar type - typedef typename TensorView::Element T; - - /// Coordinate in tensor's index space - typedef typename TensorView::TensorCoord TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element *ptr; - - /// Default ctor - CUTLASS_HOST_DEVICE - Params() { } - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - Params( - TensorView view_, ///< destination tensor - Element *ptr_ - ): - view(view_), ptr(ptr_) { - - } - }; - - // - // Data members - // - - /// Parameters object - Params params; - - // - // Methods - // - - /// Device-side initialization of RNG - CUTLASS_DEVICE - TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { - - } - - /// Compute random value and update RNG state - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - bool is_diagonal = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[0]) { - is_diagonal = false; - } - } - if (is_diagonal) { - params.ptr[coord[0]] = params.view.at(coord); - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies the diagonal of a tensor into a dense buffer in host memory. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorCopyDiagonalOut( - Element *ptr, ///< dense buffer of elements - TensorView view, ///< source tensor - cudaStream_t stream = nullptr) { - - using Func = detail::TensorCopyDiagonalOutFunc; - using Params = typename Func::Params; - - TensorForEach( - view.extent(), - Params(view, ptr), - /*grid_size*/0, /*block_size*/0, - stream - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h deleted file mode 100644 index ba2dfd85c47b8c9450c348de32dccb7f1be9c3c1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ /dev/null @@ -1,142 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include -#include "cutlass/cutlass.h" -#include "cutlass/util/reference/device/kernel/tensor_foreach.h" - -namespace cutlass { -namespace reference { -namespace device { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Launches a kernel calling a functor for each element in a tensor's index space. -template -struct TensorForEach { - - /// Constructor performs the operation. - TensorForEach( - Coord size, Params params = Params(), - int grid_size = 0, int block_size = 0, - cudaStream_t stream = nullptr) { - - if (!grid_size || !block_size) { - - // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API - cudaError_t result = cudaOccupancyMaxPotentialBlockSize( - &grid_size, - &block_size, - reinterpret_cast(kernel::TensorForEach)); - - if (result != cudaSuccess) { - throw std::runtime_error("Failed to query occupancy."); - } - // Limit block size. This has the effect of increasing the number of items processed by a - // single thread and reduces the impact of initialization overhead. - block_size = (block_size < 128 ? block_size : 128); - } - - dim3 grid(grid_size, 1, 1); - dim3 block(block_size, 1, 1); - - kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Launches a kernel calling a functor for each element along a tensor's diagonal -template -struct TensorDiagonalForEach { - - /// Constructor performs the operation - TensorDiagonalForEach( - Coord size, Params params = Params(), - int start = 0, int end = -1, - int block_size = 128, cudaStream_t stream = nullptr) { - - if (end < 0) { - end = size.min(); - } - - dim3 block(block_size, 1, 1); - dim3 grid((end - start + block_size - 1) / block_size, 1, 1); - - kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( - size, params, start, end); - } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockForEach { - - /// Constructor performs the operation. - BlockForEach( - Element *ptr, - size_t capacity, - typename Func::Params params = typename Func::Params(), - int grid_size = 0, - int block_size = 0, - cudaStream_t stream = nullptr) { - - if (!grid_size || !block_size) { - - // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API - cudaError_t result = cudaOccupancyMaxPotentialBlockSize( - &grid_size, - &block_size, - reinterpret_cast(kernel::BlockForEach)); - - if (result != cudaSuccess) { - throw std::runtime_error("Failed to query occupancy."); - } - // Limit block size. This has the effect of increasing the number of items processed by a - // single thread and reduces the impact of initialization overhead. - block_size = (block_size < 128 ? block_size : 128); - } - - dim3 grid(grid_size, 1, 1); - dim3 block(block_size, 1, 1); - - kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h deleted file mode 100644 index 3e6d7b300f34fec6aec96e72f78427cf677936b4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h +++ /dev/null @@ -1,514 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/reference/detail/linear_to_coordinate.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace device { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace kernel { - -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp, - int kBlockSize = 128 -> -__global__ void TensorTransformReducePartial( - TensorView view, /// View of the tensor to reduce over - ComputeType identity, /// Identity element of the reduction operation - ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType - TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType - ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] - - int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; - int64_t size = view.size(); - - __shared__ ComputeType scratchpad[kBlockSize]; - - for (; idx < size; idx += blockDim.x * gridDim.x) { - - // Map linear thread ID onto tensor coordinate - typename Layout::TensorCoord coord; - - cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); - - if (view.contains(coord)) { - - // Fetch element - Element x = view.at(coord); - - // Transform - identity = reduce(identity, transform(x)); - } - } - - scratchpad[threadIdx.x] = identity; - - __syncthreads(); - - // One thread performs the final reduction and stores out. This could be enhanced via - // a tree reduction and pipelining. - if (threadIdx.x == 0) { - - for (int i = 1; i < kBlockSize; ++i) { - identity = reduce(identity, scratchpad[i]); - } - - workspace[blockIdx.x] = identity; - } -} - -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp, - int kBlockSize = 128 -> -__global__ void TensorTransformReducePartial( - TensorView view_A, /// View of the tensor to reduce over - TensorView view_B, /// View of the tensor to reduce over - ComputeType identity, /// Identity element of the reduction operation - ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType - TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType - ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] - - int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; - auto size = static_cast(view_A.size()); - - __shared__ ComputeType scratchpad[kBlockSize]; - - for (; idx < size; idx += blockDim.x * gridDim.x) { - - // Map linear thread ID onto tensor coordinate - typename Layout::TensorCoord coord; - - cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); - - if (view_A.contains(coord)) { - - // Fetch element - Element a = view_A.at(coord); - Element b = view_B.at(coord); - - // Transform - identity = reduce(identity, transform(a, b)); - } - } - - scratchpad[threadIdx.x] = identity; - - __syncthreads(); - - // One thread performs the final reduction and stores out. This could be enhanced via - // a tree reduction and pipelining. - if (threadIdx.x == 0) { - - for (int i = 1; i < kBlockSize; ++i) { - identity = reduce(identity, scratchpad[i]); - } - - workspace[blockIdx.x] = identity; - } -} - - -template < - typename ComputeType, - typename ReduceOp, - int kBlockSize = 32 -> -__global__ void TensorTransformReduceFinalize( - ComputeType *workspace, - ComputeType identity, - int workspace_size, - ReduceOp reduce) { - - __shared__ ComputeType scratchpad[kBlockSize]; - - for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { - identity = reduce(identity, workspace[idx]); - } - - scratchpad[threadIdx.x] = identity; - - __syncthreads(); - - if (threadIdx.x == 0) { - - for (int i = 1; i < kBlockSize; ++i) { - identity = reduce(identity, scratchpad[i]); - } - - workspace[0] = identity; - } -} - -} // namespace kernel - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Transform-reduce operation over the elements of a tensor -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorView view, /// View of the tensor to reduce over - ComputeType identity, /// Identity element of the reduction operation - ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType - TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType - ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] - int workspace_size, /// Number of elements in workspace - cudaStream_t stream = nullptr, /// CUDA stream to launch into - bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. -) { - - int const kBlockSize = 128; - - dim3 block(kBlockSize, 1); - dim3 grid(workspace_size, 1); - - kernel::TensorTransformReducePartial< - Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize - ><<< grid, block, 0, stream >>>( - view, identity, reduce, transform, workspace - ); - - int const kFinalizeBlockSize = 32; - - kernel::TensorTransformReduceFinalize< - ComputeType, ReduceOp, kFinalizeBlockSize - ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( - workspace, identity, workspace_size, reduce - ); - - cudaStreamSynchronize(stream); - - if (copy_out) { - cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); - if (result != cudaSuccess) { - throw std::runtime_error("cudaMemcpy() failed"); - } - } - - return identity; -} - -/// Transform-reduce operation over the elements of two tensors, zipped together -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorView view_A, /// View of the tensor to reduce over - TensorView view_B, /// View of the tensor to reduce over - ComputeType identity, /// Identity element of the reduction operation - ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType - TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType - ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] - int workspace_size, /// Number of elements in workspace - cudaStream_t stream = nullptr, /// CUDA stream to launch into - bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. -) { - - if (view_A.extent() != view_B.extent()) { - throw std::runtime_error("Extents must be equal."); - } - - int const kBlockSize = 128; - - dim3 block(kBlockSize, 1); - dim3 grid(workspace_size, 1); - - kernel::TensorTransformReducePartial< - Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize - ><<< grid, block, 0, stream >>>( - view_A, view_B, identity, reduce, transform, workspace - ); - - int const kFinalizeBlockSize = 32; - - kernel::TensorTransformReduceFinalize< - ComputeType, ReduceOp, kFinalizeBlockSize - ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( - workspace, identity, workspace_size, reduce - ); - - cudaStreamSynchronize(stream); - - if (copy_out) { - cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); - if (result != cudaSuccess) { - throw std::runtime_error("cudaMemcpy() failed"); - } - } - - return identity; -} - -/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -/// workspace -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorView view, - ComputeType identity, - ReduceOp reduce, - TransformOp transform, - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - // Optionally query for the SM count to size the workspace. - if (!workspace_size) { - - int device_idx = 0; - cudaDeviceProp prop; - - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() failed"); - } - - result = cudaGetDeviceProperties(&prop, device_idx); - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProp() failed"); - } - - workspace_size = int(prop.multiProcessorCount); - } - - DeviceAllocation workspace(workspace_size); - - ComputeType output = TensorTransformReduce( - view, - identity, - reduce, - transform, - workspace.get(), - workspace_size, - stream, - true); - - return output; -} - - -/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -/// workspace -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorView view_A, - TensorView view_B, - ComputeType identity, - ReduceOp reduce, - TransformOp transform, - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - // Optionally query for the SM count to size the workspace. - if (!workspace_size) { - - int device_idx = 0; - cudaDeviceProp prop; - - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() failed"); - } - - result = cudaGetDeviceProperties(&prop, device_idx); - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProp() failed"); - } - - workspace_size = int(prop.multiProcessorCount); - } - - DeviceAllocation workspace(workspace_size); - - ComputeType output = TensorTransformReduce( - view_A, - view_B, - identity, - reduce, - transform, - workspace.get(), - workspace_size, - stream, - true); - - return output; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to compute the sum of the elements of a tensor -template < - typename Element, - typename Layout, - typename ComputeType = Element -> -ComputeType TensorSum( - TensorView view, - ComputeType identity = ComputeType(), - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - plus reduce; - NumericConverter transform; - - return TensorTransformReduce( - view, identity, reduce, transform, stream, workspace_size); -} - -/// Helper to compute the sum of the squares of the elements of a tensor -template < - typename Element, - typename Layout, - typename ComputeType = Element -> -ComputeType TensorSumSq( - TensorView view, - ComputeType identity = ComputeType(), - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - plus reduce; - magnitude_squared transform; - - return TensorTransformReduce( - view, identity, reduce, transform, stream, workspace_size); -} - -/// Helper to compute the norm of the elements of a tensor. -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorNorm( - TensorView view, - ComputeType identity = ComputeType(), - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to compute the sum of the squares of the differences of two tensors -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorSumSqDiff( - TensorView view_A, - TensorView view_B, - ComputeType identity = ComputeType(), - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - plus reduce; - magnitude_squared_difference transform; - - return TensorTransformReduce( - view_A, view_B, identity, reduce, transform, stream, workspace_size); -} - - -/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorNormDiff( - TensorView view_A, - TensorView view_B, - ComputeType identity = ComputeType(), - cudaStream_t stream = nullptr, - int workspace_size = 0 -) { - - return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h deleted file mode 100644 index 0e3d99ddf845810249f909fbdee4505a0a732c4f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h +++ /dev/null @@ -1,141 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines device-side elementwise operations on TensorView. Note, the operations defined - in this header are not specialized for any particular data layout and are therefore not - intended to offer the best possible performance. Rather, they are intended to be generic - reference implementations to support the CUTLASS unit tests. -*/ - -#pragma once - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/tensor_view.h" - -#include "cutlass/util/reference/device/tensor_foreach.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace device { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorReLuFunc { - - /// View type - using TensorView = TensorView; - - /// Coordinate in tensor's index space - using TensorCoord = typename TensorView::TensorCoord; - - /// Parameters structure - struct Params { - - // - // Data members - // - - TensorView view; - Element threshold; - - - // - // Methods - // - - Params( - TensorView view_ = TensorView(), - Element threshold_ = Element(0) - ): - view(view_), threshold(threshold_) { - - } - }; - - // - // Data members - // - - Params params; - - // - // Methods - // - - CUTLASS_DEVICE - TensorReLuFunc(Params const ¶ms): params(params) { - - } - - CUTLASS_DEVICE - void operator()(TensorCoord const &coord) { - - Element const & value = params.view.at(coord); - params.view.at(coord) = (value < params.threshold) ? params.threshold : value; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Apply ReLu on a tensor -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorReLu( - TensorView view, ///< destination tensor - Element threshold = Element(0)) { ///< ReLu threshold - - using Func = detail::TensorReLuFunc; - using Params = typename Func::Params; - - TensorForEach( - view.extent(), - Params(view, threshold) - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h deleted file mode 100644 index dd11f96bd92f6995590e61665e41a3e830bceacd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h +++ /dev/null @@ -1,186 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GEMM in host-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -namespace cutlass { -namespace reference { -namespace device { -namespace thread { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Thread-level blocked general matrix product. -// -// Note, this is a reference implementation. Performance is not expected to approach peak. -// -template < - typename TensorRefA, - typename TensorRefB, - typename TensorRefC, - typename ScalarType, - typename AccumulatorType, - typename OutputTile, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -struct Gemm { - - using ElementA = typename TensorRefA::Element; - using ElementB = typename TensorRefB::Element; - using ElementC = typename TensorRefC::Element; - - // - // Data members - // - - /// Tile for A operand - ElementA A_tile[OutputTile::kColumn]; - - /// Tile for B operand - ElementB B_tile[OutputTile::kRow]; - - /// Tile for Accumulator - AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; - - // - // Methods - // - - /// Constructor - CUTLASS_HOST_DEVICE - Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { - - // Clear fetch registers - for (int i = 0; i < OutputTile::kColumn; ++i) { - A_tile[i] = ElementA(0); - } - - for (int j = 0; j < OutputTile::kRow; ++j) { - B_tile[j] = ElementB(0); - } - - // Clear accumulators - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < OutputTile::kColumn; ++j) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < OutputTile::kRow; ++i) { - accum[j][i] = initial_accum; - } - } - } - - /// Computes a matrix product - CUTLASS_HOST_DEVICE - Gemm & multiply_add( - gemm::GemmCoord problem_size, - TensorRefA tensor_a, - TensorRefB tensor_b, - MatrixCoord output_coord = MatrixCoord()) { - - InnerProductOp inner_product_op; - - // Loop over the GEMM K dimension - CUTLASS_PRAGMA_NO_UNROLL - for (int k = 0; k < problem_size.k(); ++k) { - - // Fetch a slice of the A matrix - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < OutputTile::kColumn; ++i) { - if (output_coord.row() + i < problem_size.m()) { - A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); - } - } - - // Fetch a slice of the B matrix - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < OutputTile::kRow; ++j) { - if (output_coord.column() + j < problem_size.n()) { - B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); - } - } - - // Compute an accumulated matrix product - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < OutputTile::kRow; ++j) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < OutputTile::kColumn; ++i) { - accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); - } - } - } - - return *this; - } - - /// Performs linear scaling of matrix product and updates output tensor - CUTLASS_HOST_DEVICE - Gemm & epilogue( - gemm::GemmCoord problem_size, - ScalarType alpha, - ScalarType beta, - TensorRefC tensor_c, - TensorRefC tensor_d, - MatrixCoord output_coord = MatrixCoord()) { - - ConvertOp convert_op; - - // Update the output tensor - for (int j = 0; j < OutputTile::kRow; ++j) { - for (int i = 0; i < OutputTile::kColumn; ++i) { - MatrixCoord coord = output_coord + MatrixCoord(i, j); - if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { - - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[j][i]) + - beta * ScalarType(tensor_c.at(coord)) - ); - } - } - } - - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace device -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp deleted file mode 100644 index 57443325629ea4e5d855fe18f94c73b10a71a73a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp +++ /dev/null @@ -1,782 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for CONV in host-side code. -*/ -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" - -#include "cute/tensor.hpp" - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -bool -is_activation_in_bounds( - cute::Tensor const& activation, - int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { - return ((g_ >= 0 && g_ < size<5>(activation)) && - (n_ >= 0 && n_ < size<4>(activation)) && - (d_ >= 0 && d_ < size<3>(activation)) && - (h_ >= 0 && h_ < size<2>(activation)) && - (w_ >= 0 && w_ < size<1>(activation)) && - (c_ >= 0 && c_ < size<0>(activation))); -} - -template -bool -is_activation_in_bounds( - cute::Tensor const& activation, - int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { - return ((g_ >= 0 && g_ < size<4>(activation)) && - (n_ >= 0 && n_ < size<3>(activation)) && - (h_ >= 0 && h_ < size<2>(activation)) && - (w_ >= 0 && w_ < size<1>(activation)) && - (c_ >= 0 && c_ < size<0>(activation))); -} - -template -bool -is_activation_in_bounds( - cute::Tensor const& activation, - int32_t n_, int32_t w_, int32_t c_, int32_t g_) { - return ((g_ >= 0 && g_ < size<3>(activation)) && - (n_ >= 0 && n_ < size<2>(activation)) && - (w_ >= 0 && w_ < size<1>(activation)) && - (c_ >= 0 && c_ < size<0>(activation))); -} - -} // namespace detail - -template< - class ElementAcc_, - class ElementScalar_, - class ElementCompute_, - class ElementC_, - class ElementOut_, - bool ResidualAdd_, - class TensorAlpha_, - class TensorBeta_, - class TensorBias_, - class ActivationFunctor_ = cutlass::epilogue::thread::Identity -> -struct ConvEpilogueFusionParams { - using ElementAcc = ElementAcc_; - using ElementScalar = ElementScalar_; - using ElementCompute = ElementCompute_; - using ElementC = ElementC_; - using ElementOut = ElementOut_; - using TensorAlpha = TensorAlpha_; - using TensorBeta = TensorBeta_; - using TensorBias = TensorBias_; - using ActivationFunctor = ActivationFunctor_; - static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation - - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorAlpha tensor_alpha{}; - TensorBeta tensor_beta{}; - TensorBias tensor_bias{}; -}; - -template< - cutlass::conv::Operator ConvOp, - int NumSpatialDims, - class TensorA, - class TensorB, - class TensorC, - class TensorD, - class ShapePadding, - class StrideTraversal, - class ShapeDilation, - class EpilogueFusionParams -> -struct ConvReferenceImpl { - // Hard code accumlulator type to float to avoid data lost in accumulating add. - using ElementAcc = cutlass::platform::conditional_t, double, float>; - using ElementC = typename EpilogueFusionParams::ElementC; - using ElementOut = typename EpilogueFusionParams::ElementOut; - using ElementScalar = typename EpilogueFusionParams::ElementScalar; - using ElementCompute = typename EpilogueFusionParams::ElementCompute; - using ElementBias = typename EpilogueFusionParams::TensorBias::value_type; - using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor; - - // Input related converter - NumericConverter acc_converter; - NumericConverter residual_converter; - NumericConverter bias_converter; - // Scale related converter - NumericConverter scale_converter; - // Output related converter - NumericConverter output_converter; - - EpilogueFusionParams& epi_fusion_params_; - TensorA const& tensor_a_; - TensorB const& tensor_b_; - TensorC const& tensor_c_; - TensorD& tensor_d_; - - ShapePadding const& padding_; - StrideTraversal const& tstride_; - ShapeDilation const& dilation_; - - // Epilogue activation operation - ActivationFunctor epi_activation; - - ConvReferenceImpl( - TensorA const& tensor_a, - TensorB const& tensor_b, - TensorC const& tensor_c, - TensorD& tensor_d, - ShapePadding const& padding, - StrideTraversal const& tstride, - ShapeDilation const& dilation, - EpilogueFusionParams& epi_fusion_params) - : tensor_a_(tensor_a), - tensor_b_(tensor_b), - tensor_c_(tensor_c), - tensor_d_(tensor_d), - padding_(padding), - tstride_(tstride), - dilation_(dilation), - epi_fusion_params_(epi_fusion_params) - { - static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); - static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); - } - - void compute_reference() { - if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { - fprop_reference(cute::Int{}); - } - else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { - dgrad_reference(cute::Int{}); - } - else { - wgrad_reference(cute::Int{}); - } - } - -private: - // Specialization for 1D fprop kernel - void fprop_reference(cute::Int<1> spatial_dims) { - int32_t G = size<3>(tensor_d_); - int32_t N = size<2>(tensor_d_); - int32_t Q = size<1>(tensor_d_); - int32_t K = size<0>(tensor_d_); - int32_t S = size<1>(tensor_b_); - int32_t C = size<0>(tensor_b_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(2) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t n = 0; n < N; ++n) { - for (int32_t q = 0; q < Q; ++q) { - for (int32_t k = 0; k < K; ++k) { - auto accumulator = ElementAcc(0); - for (int32_t s = 0; s < S; ++s) { - for (int32_t c = 0; c < C; ++c) { - int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) { - auto a = tensor_a_(c, w, n, g); - auto b = tensor_b_(c, s, k, g); - accumulator += ElementAcc(a * b); - } - } - } - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? - epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? - epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[k]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); - } - tensor_d_(k, q, n, g) = output_converter(output); - } - } - } - } - - } - - // Specialization for 2D fprop kernel - void fprop_reference(cute::Int<2> spatial_dims) { - int32_t G = size<4>(tensor_d_); - int32_t N = size<3>(tensor_d_); - int32_t P = size<2>(tensor_d_); - int32_t Q = size<1>(tensor_d_); - int32_t K = size<0>(tensor_d_); - int32_t R = size<2>(tensor_b_); - int32_t S = size<1>(tensor_b_); - int32_t C = size<0>(tensor_b_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t n = 0; n < N; ++n) { - for (int32_t p = 0; p < P; ++p) { - for (int32_t q = 0; q < Q; ++q) { - for (int32_t k = 0; k < K; ++k) { - auto accumulator = ElementAcc(0); - for (int32_t r = 0; r < R; ++r) { - for (int32_t s = 0; s < S; ++s) { - for (int32_t c = 0; c < C; ++c) { - int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); - if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) { - auto a = tensor_a_(c, w, h, n, g); - auto b = tensor_b_(c, s, r, k, g); - accumulator += ElementAcc(a * b); - } - } - } - } - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? - epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? - epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[k]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); - } - tensor_d_(k, q, p, n, g) = output_converter(output); - } - } - } - } - } - - } - - // Specialization for 3D fprop kernel - void fprop_reference(cute::Int<3> spatial_dims) { - int32_t G = size<5>(tensor_d_); - int32_t N = size<4>(tensor_d_); - int32_t Z = size<3>(tensor_d_); - int32_t P = size<2>(tensor_d_); - int32_t Q = size<1>(tensor_d_); - int32_t K = size<0>(tensor_d_); - int32_t T = size<3>(tensor_b_); - int32_t R = size<2>(tensor_b_); - int32_t S = size<1>(tensor_b_); - int32_t C = size<0>(tensor_b_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t n = 0; n < N; ++n) { - for (int32_t z = 0; z < Z; ++z) { - for (int32_t p = 0; p < P; ++p) { - for (int32_t q = 0; q < Q; ++q) { - for (int32_t k = 0; k < K; ++k) { - auto accumulator = ElementAcc(0); - for (int32_t t = 0; t < T; ++t) { - for (int32_t r = 0; r < R; ++r) { - for (int32_t s = 0; s < S; ++s) { - for (int32_t c = 0; c < C; ++c) { - int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); - int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); - if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) { - auto a = tensor_a_(c, w, h, d, n, g); - auto b = tensor_b_(c, s, r, t, k, g); - accumulator += ElementAcc(a * b); - } - } - } - } - } - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? - epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? - epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[k]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); - } - tensor_d_(k, q, p, z, n, g) = output_converter(output); - } - } - } - } - } - } - - } - - // Specialization for 1D dgrad kernel - void dgrad_reference(cute::Int<1> spatial_dims) { - int32_t G = size<3>(tensor_d_); - int32_t N = size<2>(tensor_d_); - int32_t W = size<1>(tensor_d_); - int32_t C = size<0>(tensor_d_); - int32_t K = size<2>(tensor_b_); - int32_t S = size<1>(tensor_b_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(2) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t n = 0; n < N; ++n) { - for (int32_t w = 0; w < W; ++w) { - for (int32_t c = 0; c < C; ++c) { - auto accumulator = ElementAcc(0); - for (int32_t k = 0; k < K; ++k) { - for (int32_t s = 0; s < S; ++s) { - int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); - - if (q % cute::get<0>(tstride_) == 0) { - q /= cute::get<0>(tstride_); - } else { - continue; - } - - if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) { - accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g)); - } - } - } - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) - ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) - ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[c]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); - } - tensor_d_(c, w, n, g) = output_converter(output); - } - } - } - } - - } - - // Specialization for 2D dgrad kernel - void dgrad_reference(cute::Int<2> spatial_dims) { - int32_t G = size<4>(tensor_d_); - int32_t N = size<3>(tensor_d_); - int32_t H = size<2>(tensor_d_); - int32_t W = size<1>(tensor_d_); - int32_t C = size<0>(tensor_d_); - int32_t K = size<3>(tensor_b_); - int32_t R = size<2>(tensor_b_); - int32_t S = size<1>(tensor_b_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t n = 0; n < N; ++n) { - for (int32_t h = 0; h < H; ++h) { - for (int32_t w = 0; w < W; ++w) { - for (int32_t c = 0; c < C; ++c) { - auto accumulator = ElementAcc(0); - for (int32_t k = 0; k < K; ++k) { - for (int32_t r = 0; r < R; ++r) { - for (int32_t s = 0; s < S; ++s) { - int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); - int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); - - if (q % cute::get<0>(tstride_) == 0) { - q /= cute::get<0>(tstride_); - } else { - continue; - } - - if (p % cute::get<1>(tstride_) == 0) { - p /= cute::get<1>(tstride_); - } else { - continue; - } - - if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) { - accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g)); - } - } - } - } - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) - ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) - ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[c]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); - } - - tensor_d_(c, w, h, n, g) = output_converter(output); - } - } - } - } - } - - } - - // Specialization for 3D dgrad kernel - void dgrad_reference(cute::Int<3> spatial_dims) { - int32_t G = size<5>(tensor_d_); - int32_t N = size<4>(tensor_d_); - int32_t D = size<3>(tensor_d_); - int32_t H = size<2>(tensor_d_); - int32_t W = size<1>(tensor_d_); - int32_t C = size<0>(tensor_d_); - int32_t K = size<4>(tensor_b_); - int32_t T = size<3>(tensor_b_); - int32_t R = size<2>(tensor_b_); - int32_t S = size<1>(tensor_b_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t n = 0; n < N; ++n) { - for (int32_t d = 0; d < D; ++d) { - for (int32_t h = 0; h < H; ++h) { - for (int32_t w = 0; w < W; ++w) { - for (int32_t c = 0; c < C; ++c) { - auto accumulator = ElementAcc(0); - for (int32_t k = 0; k < K; ++k) { - for (int32_t t = 0; t < T; ++t) { - for (int32_t r = 0; r < R; ++r) { - for (int32_t s = 0; s < S; ++s) { - int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); - int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); - int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_); - - if (q % cute::get<0>(tstride_) == 0) { - q /= cute::get<0>(tstride_); - } else { - continue; - } - - if (p % cute::get<1>(tstride_) == 0) { - p /= cute::get<1>(tstride_); - } else { - continue; - } - - if (z % cute::get<2>(tstride_) == 0) { - z /= cute::get<2>(tstride_); - } else { - continue; - } - - if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) { - accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g)); - } - } - } - } - } - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) - ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) - ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[c]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); - } - tensor_d_(c, w, h, d, n, g) = output_converter(output); - } - } - } - } - } - } - - } - - // Specialization for 1D wgrad kernel - void wgrad_reference(cute::Int<1> spatial_dims) { - int32_t G = size<3>(tensor_d_); - int32_t N = - size<2>(tensor_a_); - int32_t Q = - size<1>(tensor_a_); - int32_t K = - size<0>(tensor_a_); - int32_t S = size<1>(tensor_d_); - int32_t C = size<0>(tensor_d_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(2) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t k = 0; k < K; ++k) { - for (int32_t s = 0; s < S; ++s) { - for (int32_t c = 0; c < C; ++c) { - auto accumulator = ElementAcc(0); - for (int32_t n = 0; n < N; ++n) { - for (int32_t q = 0; q < Q; ++q) { - int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - bool is_in_bounds = - detail::is_activation_in_bounds(tensor_b_, n, w, c, g); - if (is_in_bounds) { - auto act = - tensor_b_(c, w, n, g); - auto xformed_act = - tensor_a_(k, q, n, g); - accumulator += ElementAcc(act * xformed_act); - } - } - } - - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? - epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? - epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[c]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); - } - tensor_d_(c, s, k, g) = output_converter(output); - } - } - } - } - } - - // Specialization for 2D wgrad kernel - void wgrad_reference(cute::Int<2> spatial_dims) { - int32_t G = size<4>(tensor_d_); - int32_t N = - size<3>(tensor_a_); - int32_t P = - size<2>(tensor_a_); - int32_t Q = - size<1>(tensor_a_); - int32_t K = - size<0>(tensor_a_); - int32_t R = size<2>(tensor_d_); - int32_t S = size<1>(tensor_d_); - int32_t C = size<0>(tensor_d_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int32_t g = 0; g < G; ++g) { - for (int32_t k = 0; k < K; ++k) { - for (int32_t r = 0; r < R; ++r) { - for (int32_t s = 0; s < S; ++s) { - for (int32_t c = 0; c < C; ++c) { - auto accumulator = ElementAcc(0); - for (int32_t n = 0; n < N; ++n) { - for (int32_t p = 0; p < P; ++p) { - for (int32_t q = 0; q < Q; ++q) { - int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); - bool is_in_bounds = - detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g); - if (is_in_bounds) { - auto act = - tensor_b_(c, w, h, n, g); - auto xformed_act = - tensor_a_(k, q, p, n, g); - accumulator += ElementAcc(act * xformed_act); - } - } - } - } - - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? - epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? - epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[c]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); - } - tensor_d_(c, s, r, k, g) = output_converter(output); - } - } - } - } - } - } - - // Specialization for 3D wgrad kernel - void wgrad_reference(cute::Int<3> spatial_dims) { - int32_t G = size<5>(tensor_d_); - int32_t N = - size<4>(tensor_a_); - int32_t Z = - size<3>(tensor_a_); - int32_t P = - size<2>(tensor_a_); - int32_t Q = - size<1>(tensor_a_); - int32_t K = - size<0>(tensor_a_); - int32_t T = size<3>(tensor_d_); - int32_t R = size<2>(tensor_d_); - int32_t S = size<1>(tensor_d_); - int32_t C = size<0>(tensor_d_); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int32_t g = 0 ; g < G; ++g) { - for (int32_t k = 0; k < K; ++k) { - for (int32_t t = 0; t < T; ++t) { - for (int32_t r = 0; r < R; ++r) { - for (int32_t s = 0; s < S; ++s) { - for (int32_t c = 0; c < C; ++c) { - auto accumulator = ElementAcc(0); - for (int32_t n = 0; n < N; ++n) { - for (int32_t z = 0; z < Z; ++z) { - for (int32_t p = 0; p < P; ++p) { - for (int32_t q = 0; q < Q; ++q) { - int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); - int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); - bool is_in_bounds = - detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g); - if (is_in_bounds) { - auto act = - tensor_b_(c, w, h, d, n, g); - auto xformed_act = - tensor_a_(k, q, p, z, n, g); - accumulator += ElementAcc(act * xformed_act); - } - } - } - } - } - - ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? - epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; - ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? - epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); - if (not EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); - } - if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { - output += bias_converter(epi_fusion_params_.tensor_bias[c]); - } - output = epi_activation(output); - if (EpilogueFusionParams::ResidualAdd) { - output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); - } - tensor_d_(c, s, r, t, k, g) = output_converter(output); - } - } - } - } - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h deleted file mode 100644 index 73298e5794f0f2658ef18fb3f46466c400fc831e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h +++ /dev/null @@ -1,802 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Reference implementation for convolution in host-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/functional.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/conv/convolution.h" -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/conv3d_problem_size.h" -#include - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Forward propagation -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// y = conv2d(x, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv2dFprop( - conv::Conv2dProblemSize problem_size, - TensorRef tensor_x, - TensorRef tensor_w, - TensorRef tensor_y_in, - TensorRef tensor_y_out, - ElementCompute alpha, - ElementCompute beta) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - // Apply MMA and accumulate ElementAccumulator - for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { - - int group_idx = k / (problem_size.K / problem_size.groups); - int channels_per_group = problem_size.C / problem_size.groups; - - ElementAccumulator acc = ElementAccumulator(); - - for (int r = 0; r < problem_size.R; ++r) { - for (int s = 0; s < problem_size.S; ++s) { - for (int c = 0; c < channels_per_group; ++c) { - - int filter_r = r; - int filter_s = s; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_r = problem_size.R - 1 - r; - filter_s = problem_size.S - 1 - s; - } - - int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; - int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; - - if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { - - ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); - ElementB b = tensor_w.at({k, r, s, c}); - - acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); - - } - } - } - } - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = ElementC(); - - if (beta != ElementCompute()) { - c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); - } - - tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - } - } - } - } -} - -/// Depthwise-separable convolution -template , - typename InnerProductOp = multiply_add> -void Depsep_Fprop(cutlass::TensorView tensor_A, - cutlass::TensorView tensor_B, - cutlass::TensorView tensor_C, - cutlass::TensorView tensor_D, - ElementCompute alpha, - ElementCompute beta, - cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), - cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), - cutlass::Coord<2> dilation = cutlass::Coord<2>(), - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - // Apply MMA and accumulate ElementAccumulator - for (int n = 0; n < tensor_C.extent().n(); ++n) { - for (int p = 0; p < tensor_C.extent().h(); ++p) { - for (int q = 0; q < tensor_C.extent().w(); ++q) { - for (int g = 0; g < tensor_C.extent().c(); ++g) { - ElementAccumulator acc = ElementAccumulator(); - for (int r = 0; r < tensor_B.extent().h(); ++r) { - for (int s = 0; s < tensor_B.extent().w(); ++s) { - - // input activation H and W - int h = p * conv_stride[0] - padding[0] + r * dilation[0]; - int w = q * conv_stride[1] - padding[2] + s * dilation[1]; - - if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { - ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); - - ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) - ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) - : tensor_B.at(cutlass::make_Coord( - g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); - - acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); - } - } - } - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); - tensor_D.at(cutlass::make_Coord(n, p, q, g)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Dgrad / Deconv -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// dx = dgrad(dy, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv2dDgrad( - cutlass::conv::Conv2dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_w, - TensorRef tensor_dx_in, - TensorRef tensor_dx_out, - ElementCompute alpha, - ElementCompute beta, - bool is_deconv = false) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - // Apply MMA and accumulate ElementAccumulator - for (int n = 0; n < problem_size.N; ++n) { - for (int h = 0; h < problem_size.H; ++h) { - for (int w = 0; w < problem_size.W; ++w) { - for (int c = 0; c < problem_size.C; ++c) { - - ElementAccumulator acc = ElementAccumulator(); - - for (int r = 0; r < problem_size.R; ++r) { - for (int s = 0; s < problem_size.S; ++s) { - for (int k = 0; k < problem_size.K; ++k) { - - int filter_r = r; - int filter_s = s; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_r = problem_size.R - 1 - r; - filter_s = problem_size.S - 1 - s; - } - - int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; - int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; - - if (p >= 0 && (p % problem_size.stride_h) == 0 && - q >= 0 && (q % problem_size.stride_w) == 0) { - - p = p / problem_size.stride_h; - q = q / problem_size.stride_w; -#if 0 - std::cout << "row:" - << n * problem_size.H * problem_size.W + - h * problem_size.W + - w << " " - << "n, p, q: (" - << n << ", " - << p << ", " - << q << ") * " - << "r, s: (" - << r << ", " - << s << ") [" - << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" - << std::endl; -#endif - if (p < problem_size.P && q < problem_size.Q) { - - ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); - ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k)) - : tensor_w.at(cutlass::make_Coord(k, r, s, c)); - - acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); - } - } - - } // for (K) - } // for (S) - } // for (R) - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = ElementC(); - - if (beta != ElementCompute()) { - c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); - } - - tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - - } // for (C) - } // for (W) - } // for (H) - } // for (N) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Wgrad -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// dw = wgrad(dy, x) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv2dWgrad( - cutlass::conv::Conv2dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_x, - TensorRef tensor_dw_in, - TensorRef tensor_dw_out, - ElementCompute alpha, - ElementCompute beta) { - - InnerProductOp inner_product_op; - ConvertOp convert_op; - - // Apply MMA and accumulate ElementAccumulator - for (int k = 0; k < problem_size.K; ++k) { - for (int r = 0; r < problem_size.R; ++r) { - for (int s = 0; s < problem_size.S; ++s) { - for (int c = 0; c < problem_size.C; ++c) { - - ElementAccumulator acc = ElementAccumulator(); - - for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - - cutlass::Tensor4DCoord b_coord; - - int filter_r = r; - int filter_s = s; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_r = problem_size.R - 1 - r; - filter_s = problem_size.S - 1 - s; - } - - b_coord = make_Coord( - n, - p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, - q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, - c); - - if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && - b_coord.w() < problem_size.W && b_coord.w() >= 0) { - - ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); - ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); - acc = inner_product_op(a, b, acc); - } - } - } - } - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = ElementC(); - - if (beta != ElementCompute()) { - c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); - } - - tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - - } // for (C) - } // for (S) - } // for (R) - } // for (K) -} - -/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv2d( - conv::Operator convolutional_operator, - conv::Conv2dProblemSize problem_size, - TensorRef tensor_A, - TensorRef tensor_B, - TensorRef tensor_C, - TensorRef tensor_D, - ElementCompute alpha, - ElementCompute beta) { - - switch (convolutional_operator) { - case conv::Operator::kFprop: - Conv2dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ElementD, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); - break; - - case conv::Operator::kDeconv: - case conv::Operator::kDgrad: - Conv2dDgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ElementD, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); - break; - - case conv::Operator::kWgrad: - Conv2dWgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ElementD, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); - break; - - default: - break; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// 3D convolution -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// y = conv3d(x, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv3dFprop( - conv::Conv3dProblemSize problem_size, - TensorRef tensor_x, - TensorRef tensor_w, - TensorRef tensor_y_in, - TensorRef tensor_y_out, - ElementCompute alpha, - ElementCompute beta) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - // Apply MMA and accumulate ElementAccumulator - for (int n = 0; n < problem_size.N; ++n) { - for (int z = 0; z < problem_size.Z; ++z) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { - - ElementAccumulator acc = ElementAccumulator(); - - for (int t = 0; t < problem_size.T; ++t) { - for (int r = 0; r < problem_size.R; ++r) { - for (int s = 0; s < problem_size.S; ++s) { - for (int c = 0; c < problem_size.C; ++c) { - - int filter_t = t; - int filter_r = r; - int filter_s = s; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_t = problem_size.T - 1 - t; - filter_r = problem_size.R - 1 - r; - filter_s = problem_size.S - 1 - s; - } - - int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; - int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; - int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; - - if (d >= 0 && d < problem_size.D && - h >=0 && h < problem_size.H && - w >= 0 && w < problem_size.W) { - - ElementA a = tensor_x.at({n, d, h, w, c}); - ElementB b = tensor_w.at({k, t, r, s, c}); - - acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); - } - } - } - } - } - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = ElementC(); - - if (beta != ElementCompute()) { - c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); - } - - tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Dgrad / Deconv -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// dx = dgrad(dy, w) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv3dDgrad( - cutlass::conv::Conv3dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_w, - TensorRef tensor_dx_in, - TensorRef tensor_dx_out, - ElementCompute alpha, - ElementCompute beta, - bool is_deconv = false) { - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - // Apply MMA and accumulate ElementAccumulator - for (int n = 0; n < problem_size.N; ++n) { - for (int d = 0; d < problem_size.D; ++d) { - for (int h = 0; h < problem_size.H; ++h) { - for (int w = 0; w < problem_size.W; ++w) { - for (int c = 0; c < problem_size.C; ++c) { - - ElementAccumulator acc = ElementAccumulator(); - - for (int t = 0; t < problem_size.T; ++t) { - for (int r = 0; r < problem_size.R; ++r) { - for (int s = 0; s < problem_size.S; ++s) { - for (int k = 0; k < problem_size.K; ++k) { - - int filter_t = t; - int filter_r = r; - int filter_s = s; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_t = problem_size.T - 1 - t; - filter_r = problem_size.R - 1 - r; - filter_s = problem_size.S - 1 - s; - } - - int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; - int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; - int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; - - if (z >= 0 && (z % problem_size.stride_d) == 0 && - p >= 0 && (p % problem_size.stride_h) == 0 && - q >= 0 && (q % problem_size.stride_w) == 0) { - - z = z / problem_size.stride_d; - p = p / problem_size.stride_h; - q = q / problem_size.stride_w; - - if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { - - ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); - ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) - : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); - acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); - } - } - - } // for (K) - } // for (S) - } // for (R) - } // for (T) - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = ElementC(); - - if (beta != ElementCompute()) { - c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); - } - - tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - - } // for (C) - } // for (W) - } // for (H) - } // for (D) - } // for (N) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Wgrad -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// dw = wgrad(dy, x) -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv3dWgrad( - cutlass::conv::Conv3dProblemSize problem_size, - TensorRef tensor_dy, - TensorRef tensor_x, - TensorRef tensor_dw_in, - TensorRef tensor_dw_out, - ElementCompute alpha, - ElementCompute beta) { - - InnerProductOp inner_product_op; - ConvertOp convert_op; - - // Apply MMA and accumulate ElementAccumulator - for (int k = 0; k < problem_size.K; ++k) { - for (int t = 0; t < problem_size.T; ++t) { - for (int r = 0; r < problem_size.R; ++r) { - for (int s = 0; s < problem_size.S; ++s) { - for (int c = 0; c < problem_size.C; ++c) { - - ElementAccumulator acc = ElementAccumulator(); - - for (int n = 0; n < problem_size.N; ++n) { - for (int z = 0; z < problem_size.Z; ++z) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - - int filter_t = t; - int filter_r = r; - int filter_s = s; - - if (problem_size.mode == cutlass::conv::Mode::kConvolution) { - filter_t = problem_size.T - 1 - t; - filter_r = problem_size.R - 1 - r; - filter_s = problem_size.S - 1 - s; - } - - Tensor5DCoord b_coord = make_Coord( - n, - z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, - p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, - q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, - c); - - if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && - b_coord.h() < problem_size.H && b_coord.h() >= 0 && - b_coord.w() < problem_size.W && b_coord.w() >= 0) { - - ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); - ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); - - acc = inner_product_op(a, b, acc); - } - } - } - } - } - - // Apply Epilogue, compute ElementCompute, convert and store ElementC - ElementC c_ref = ElementC(); - - if (beta != ElementCompute()) { - c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); - } - - tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = - convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); - - } // for (C) - } // for (S) - } // for (R) - } // for (T) - } // for (K) -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ElementCompute, - typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Conv3d( - conv::Operator convolutional_operator, - conv::Conv3dProblemSize problem_size, - TensorRef tensor_A, - TensorRef tensor_B, - TensorRef tensor_C, - TensorRef tensor_D, - ElementCompute alpha, - ElementCompute beta) { - - switch (convolutional_operator) { - case conv::Operator::kFprop: - Conv3dFprop< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); - break; - - case conv::Operator::kDeconv: - case conv::Operator::kDgrad: - Conv3dDgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); - break; - - case conv::Operator::kWgrad: - Conv3dWgrad< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, - ElementAccumulator, - ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); - break; - - default: - break; - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h deleted file mode 100644 index 12ead83354b785096e8029b49f1ac353d5ce5f82..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h +++ /dev/null @@ -1,66 +0,0 @@ - -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/util/reference/host/tensor_reduce.h" -#include "cutlass/core_io.h" - -namespace cutlass { -namespace reference { -namespace host { - -/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorRelativeErrorMetric( - TensorView view_A_computed, - TensorView view_B_reference, - ComputeType identity = ComputeType() -) { - - return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / - cutlass::reference::host::TensorNorm(view_B_reference, identity); -} - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h deleted file mode 100644 index 2afee7b36d9822cc196f0f167f9dbec4c295d1a6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h +++ /dev/null @@ -1,531 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GEMM in host-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/mma.h" -#include "cutlass/util/host_tensor.h" - -namespace cutlass { -namespace reference { -namespace host { - -template -struct CastIfScalar { - static Out cast(In in) { - return Out(in); - } -}; - -template -struct CastIfScalar, In> { - typedef cutlass::complex Out; - static Out cast(In in) { - return Out(static_cast(in)); - } -}; - -template -struct CastIfScalar, cutlass::complex> { - typedef cutlass::complex Out; - typedef cutlass::complex In; - static Out cast(In in) { - return Out(in); - } -}; - -template -Out cast_if_scalar(In in) { - return CastIfScalar::cast(in); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_gemm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b = tensor_b.at(MatrixCoord(k_block, col)); - - ComputeType compute_a(cast_if_scalar(a)); - ComputeType compute_b(cast_if_scalar(b)); - - accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * ScalarType(tensor_c.at(coord))); - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_gemm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum) { - compute_gemm( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, - initial_accum); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = cutlass::arch::OpMultiplyAdd -> -struct Gemm; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add-saturate -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm, - NumericConverterClamp>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm, - NumericConverterClamp>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for XOR-popc -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -/// Partial specialization for AND-popc -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Gemm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_gemm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Batched GEMM -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a batch of GEMMs over a set of matrices of common dimension. -// -// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -// -template < - typename TensorRefCollectionA, - typename TensorRefCollectionB, - typename TensorRefCollectionC, - typename ScalarType, - typename AccumulatorType -> -void BatchedGemm( - gemm::GemmCoord problem_size, - int batch_count, - ScalarType alpha, - TensorRefCollectionA const& tensor_a, - TensorRefCollectionB const& tensor_b, - ScalarType beta, - TensorRefCollectionC &tensor_c, - AccumulatorType initial_accum) { - - typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); - typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); - typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); - - for (int batch = 0; - batch < batch_count; - ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { - - Gemm - gemm; - - gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, - initial_accum); - } -} - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -// -// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -// -template < - typename TensorRefCollectionA, - typename TensorRefCollectionB, - typename TensorRefCollectionC, - typename ScalarType, - typename AccumulatorType -> -void BatchedGemm( - gemm::GemmCoord problem_size, - int batch_count, - ScalarType alpha, - TensorRefCollectionA const& tensor_a, - TensorRefCollectionB const& tensor_b, - ScalarType beta, - TensorRefCollectionC &tensor_c) { - - BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h deleted file mode 100644 index 221a6040854a74ce465af7b021bbbfae9b96a90b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h +++ /dev/null @@ -1,210 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued GEMM in host-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/matrix_coord.h" - -#include "cutlass/tensor_view.h" - -#include "cutlass/gemm/gemm.h" - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ElementD = ElementC, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void GemmComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { - - // Compute matrix product using blocks - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b = tensor_b.at(MatrixCoord(k_block, col)); - - ComputeType a_ik = ComputeType(a); - ComputeType b_kj = ComputeType(b); - - if (transform_a == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } - - if (transform_b == ComplexTransform::kConjugate) { - b_kj = conj(b_kj); - } - - accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * ScalarType(tensor_c.at(coord))); - } - } - } - - } // for (col_block) - } // for (row_block) - - tensor_a.add_pointer_offset(batch_stride_A); - tensor_b.add_pointer_offset(batch_stride_B); - tensor_c.add_pointer_offset(batch_stride_C); - tensor_d.add_pointer_offset(batch_stride_D); - - } // for (batch_idx) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ElementD = ElementC -> -void GemmComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d) { - - GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h deleted file mode 100644 index 507c37d9eb5a8c998f1075d547e8430b2edc5685..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +++ /dev/null @@ -1,228 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued GEMM in host-side code. -*/ - -#pragma once - -#include "cutlass/coord.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_ref_planar_complex.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add> -> -void GemmPlanarComplex( - gemm::GemmCoord problem_size, - complex alpha, - TensorRefPlanarComplex tensor_a, - ComplexTransform transform_a, - TensorRefPlanarComplex tensor_b, - ComplexTransform transform_b, - complex beta, - TensorRefPlanarComplex tensor_c, - TensorRefPlanarComplex tensor_d, - complex initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - using ComplexA = typename TensorRefPlanarComplex::ComplexElement; - using ComplexB = typename TensorRefPlanarComplex::ComplexElement; - using ComplexC = typename TensorRefPlanarComplex::ComplexElement; - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - complex accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - - ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); - ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); - - complex a = complex{ - ComputeType(a_ik.real()), - ComputeType(a_ik.imag()) - }; - - complex b = complex{ - ComputeType(b_kj.real()), - ComputeType(b_kj.imag()) - }; - - if (transform_a == ComplexTransform::kConjugate) { - a = conj(a); - } - - if (transform_b == ComplexTransform::kConjugate) { - b = conj(b); - } - - accum[i][j] = inner_product_op(a, b, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - - complex acc{ - ScalarType(accum[i][j].real()), - ScalarType(accum[i][j].imag()) - }; - - ComplexC d_ij = tensor_c.at(coord); - - complex src{ - ScalarType(d_ij.real()), - ScalarType(d_ij.imag()) - }; - - complex result = alpha * acc + beta * src; - - d_ij.real() = convert_op(result.real()); - d_ij.imag() = convert_op(result.imag()); - - tensor_d.at(coord) = d_ij; - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType -> -void GemmPlanarComplex( - gemm::GemmCoord problem_size, - complex alpha, - TensorRefPlanarComplex tensor_a, - ComplexTransform transform_a, - TensorRefPlanarComplex tensor_b, - ComplexTransform transform_b, - complex beta, - TensorRefPlanarComplex tensor_c, - TensorRefPlanarComplex tensor_d) { - - GemmPlanarComplex( - problem_size, - alpha, - tensor_a, transform_a, - tensor_b, transform_b, - beta, - tensor_c, - tensor_d, - complex()); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp deleted file mode 100644 index dd54dc6e378d0d0f0549ec922da8357841ac558f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp +++ /dev/null @@ -1,916 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GETT in host-side code. -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" - -#include "cute/tensor.hpp" -#include "cute/pointer.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/////////////////////////////////////////////////////////// -// -// Gett Mainloop Parameters -// -/////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_ // (N, K, L) - - , class TensorSfA_ = TensorA_, - class TensorSfB_ = TensorB_ - -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename TensorB::layout_type; - - TensorA A{}; - TensorB B{}; - - ComplexTransform transform_A = ComplexTransform::kNone; - ComplexTransform transform_B = ComplexTransform::kNone; - - - using TensorSfA = TensorSfA_; - using TensorSfB = TensorSfB_; - using EngineSfA = typename TensorSfA::engine_type; - using LayoutSfA = typename TensorSfA::layout_type; - using EngineSfB = typename TensorSfB::engine_type; - using LayoutSfB = typename TensorSfB::layout_type; - TensorSfA_ SfA{}; - TensorSfB_ SfB{}; - - - GettMainloopParams() {} - - GettMainloopParams(TensorA tensor_A, TensorB tensor_B) - : A(tensor_A), B(tensor_B) {} - - - GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) - : A(tensor_A), SfA(tensor_SfA), - B(tensor_B), SfB(tensor_SfB) {} - - -}; - - - -//////////////////////////////////////////////////////////////////////// -// -// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels -// -//////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorSfA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorSfB_ // (N, K, L) -> -struct GettBlockScalingMainloopParams : public GettMainloopParams { - using Base = GettMainloopParams; - using ElementAccumulator = typename Base::ElementAccumulator; - using TensorA = typename Base::TensorA; - using TensorB = typename Base::TensorB; - using EngineA = typename Base::EngineA; - using LayoutA = typename Base::LayoutA; - using EngineB = typename Base::EngineB; - using LayoutB = typename Base::LayoutB; - ComplexTransform transform_A = Base::transform_A; - ComplexTransform transform_B = Base::transform_B; - - using TensorSfA = typename Base::TensorSfA; - using TensorSfB = typename Base::TensorSfB; - using EngineSfA = typename Base::EngineSfA; - using LayoutSfA = typename Base::LayoutSfA; - using EngineSfB = typename Base::EngineSfB; - using LayoutSfB = typename Base::LayoutSfB; - - GettBlockScalingMainloopParams() {} - - GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) - : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {} - - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -enum class SfStrategy { - None = 0, - SfDGen = 1 -}; - - -/////////////////////////////////////////////////////////// -// -// Gett Epilogue Parameters -// -/////////////////////////////////////////////////////////// - -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) - class TensorAux_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, N, L) - class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class TensorSFD_ = TensorD_, - class SFD_VectorSize_ = cute::Int<0>, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false - , - SfStrategy SfGenStrategy_ = SfStrategy::None -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using TensorSFD = TensorSFD_; - using SFD_VectorSize = SFD_VectorSize_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - using EngineSfD = typename TensorSFD::engine_type; - using LayoutSfD = typename TensorSFD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - static constexpr SfStrategy SfGenStrategy = SfGenStrategy_; - - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - TensorSFD SfD{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux = nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = ElementScalingFactor(1); - ElementScalingFactor scale_aux = ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; - GettEpilogueParams() {} - - GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) - : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {} - - - GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) - : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {} - - - GettEpilogueParams( - ElementScalar alpha, ElementScalar beta, - TensorC tensor_C, TensorD tensor_D, - VectorBias bias, TensorAux tensor_aux, - VectorAlpha vector_alpha, VectorBeta vector_beta) - : alpha(alpha), beta(beta), - C(tensor_C), D(tensor_D), - Bias(bias), Aux(tensor_aux), - Valpha(vector_alpha), Vbeta(vector_beta) {} -}; - - - -//////////////////////////////////////////////////////////////////////// -// -// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels -// -//////////////////////////////////////////////////////////////////////// - -template< - class ElementScalar_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, - class TensorD_, - class TensorSfD_ = TensorD_, - class SFD_VectorSize_ = cute::Int<0>, - SfStrategy SfGenStrategy_ = SfStrategy::None -> -struct GettBlockScalingEpilogueParams : public GettEpilogueParams< - ElementScalar_, // ElementScalar - ElementScalar_, // ElementScalingFactor - ElementAccumulator_, // ElementAccumulator - ElementCompute_, // ElementCompute - TensorC_, // TensorC (M, N, L) - TensorD_, // TensorD (M, N, L) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) - cutlass::epilogue::thread::Identity, // - TensorSfD_, // TensorSfD - SFD_VectorSize_, // SFD_VectorSize - cutlass::plus, // class BiasBinaryOp_ = - false, //PerColumnBias_ - SfGenStrategy_ // SfGenStrategy - > { - using Base = GettEpilogueParams< - ElementScalar_, // ElementScalar - ElementScalar_, // ElementScalingFactor - ElementAccumulator_, // ElementAccumulator - ElementCompute_, // ElementCompute - TensorC_, // TensorC (M, N, L) - TensorD_, // TensorD (M, N, L) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) - decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) - cutlass::epilogue::thread::Identity, // - TensorSfD_, // TensorSfD - SFD_VectorSize_, // SFD_VectorSize - cutlass::plus, // BiasBinaryOp - false, // PerColumnBias - SfGenStrategy_ // SfGenStrategy - >; - using ElementScalar = typename Base::ElementScalar; - using ElementScalingFactor = typename Base::ElementScalingFactor; - using ElementAccumulator = typename Base::ElementAccumulator; - using ElementCompute = typename Base::ElementCompute; - using TensorC = typename Base::TensorC; - using TensorD = typename Base::TensorD; - using TensorAux = typename Base::TensorAux; - using VectorBias = typename Base::VectorBias; - using VectorAlpha = typename Base::VectorAlpha; - using VectorBeta = typename Base::VectorBeta; - using TensorSFD = typename Base::TensorSFD; - using SFD_VectorSize = typename Base::SFD_VectorSize; - using ActivationFunctor = typename Base::ActivationFunctor; - using BiasBinaryOp = typename Base::BiasBinaryOp; - - using EngineC = typename Base::EngineC; - using LayoutC = typename Base::LayoutC; - using EngineD = typename Base::EngineD; - using LayoutD = typename Base::LayoutD; - using EngineSfD = typename Base::EngineSfD; - using LayoutSfD = typename Base::LayoutSfD; - static constexpr bool PerColumnBias = Base::PerColumnBias; - static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy; - - GettBlockScalingEpilogueParams() {} - - GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) - : Base(alpha, beta, tensor_C, tensor_D) {} - - GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD) - : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {} - - GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) - : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {} -}; - - - - - -/////////////////////////////////////////////////////////// -// -// Generic Gett 3x Implementation -// -/////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// -template -void compute_1d_scaling_factor_and_quantized_output( - EpilogueParams const& epilogue_params, - TensorD &tensor_D, - TensorSFD &tensor_SfD, - int64_t m, - int64_t n, - int64_t l, - ElementCompute (&acc)[kBlockM][kBlockN]) -{ - using ElementD = typename ElementTraits::type; - using ElementSfD = typename ElementTraits::type; - - int const M = cute::size<0>(tensor_D.layout()); - int const N = cute::size<1>(tensor_D.layout()); - int const L = cute::size<2>(tensor_D.layout()); - - auto mul = cutlass::multiplies{}; - auto div = divides{}; - // Get FP max - ElementCompute fp_max = ElementCompute(std::numeric_limits::max()); - float scale_down_factor = div(1.0f, fp_max); - // Get st' = st / FP max - ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor); - - absolute_value_op abs_op; - maximum_with_nan_propogation max_op; - - if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) { - // MN major output - int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); - // Col major output - for (int n_b = 0; n_b < kBlockN; ++n_b) { - for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { - int64_t col = n + n_b; - - /// Step1: get max across a vector - ElementCompute accum_max = ElementCompute(0); - for (int v = 0; v < kVectorSize; v++) { - int accum_row = v_b * kVectorSize + v; - int64_t output_row = accum_row + m; - if (output_row < M && col < N) { - accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b])); - } - } - - /// Step2: Compute Scale - ElementCompute pvscale = mul(accum_max, st_scaled_down); - ElementSfD qpvscale = static_cast(pvscale); - // Store the Scaling Factors - int64_t sf_row = m + kVectorSize * v_b; - if (sf_row < M && col < N) { - tensor_SfD(sf_row, col, l) = qpvscale; - } - - /// Step3: Compute quantized output values - ElementCompute qpvscale_up = NumericConverter{}(qpvscale); - // Get float reciprocal - ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); - ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); - // Map INF to fp32::max - acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - // Store the intermediate_accum - for (int v = 0; v < kVectorSize; v++) { - int accum_row = v_b * kVectorSize + v; - int64_t output_row = accum_row + m; - if (output_row < M && col < N) { - acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale); - } - } - } - } - } - else { - int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize); - // row major output - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { - int64_t row = m + m_b; - - /// Step1: get max across a vector - ElementCompute accum_max = ElementCompute(0); - for (int v = 0; v < kVectorSize; v++) { - int accum_col = v_b * kVectorSize + v; - int64_t output_col = accum_col + n; - if (row < M && output_col < N) { - accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col])); - } - } - - /// Step2: Compute Scale - ElementCompute pvscale = mul(accum_max, st_scaled_down); - ElementSfD qpvscale = static_cast(pvscale); - // Store the Scaling Factors - int64_t sf_col = n + kVectorSize * v_b; - - if (row < M && sf_col < N) { - tensor_SfD(row, sf_col, l) = qpvscale; - } - - /// Step3: Compute quantized output values - ElementCompute qpvscale_up = NumericConverter{}(qpvscale); - // Get float reciprocal - ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); - ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); - // Map INF to fp32::max - acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - // Store the intermediate_accum - for (int v = 0; v < kVectorSize; v++) { - int accum_col = v_b * kVectorSize + v; - int64_t output_col = accum_col + n; - if (row < M && output_col < N) { - acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale); - } - } - } - } - } -} - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr kBlockM = 64; - static int constexpr kBlockN = 64; - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - - - using ElementSFA = typename ElementTraits::type; - using ElementSFB = typename ElementTraits::type; - - - using RingOp = multiply_add; - RingOp fma_op; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - - - if constexpr (not cute::is_same_v){ - // Load SFA - auto sfa = static_cast(mainloop_params.SfA(m + m_b, k, l)); - a_frag[m_b] *= sfa; - } - - - if (mainloop_params.transform_A == ComplexTransform::kConjugate) { - a_frag[m_b] = conj(a_frag[m_b]); - } - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - - - if constexpr (not cute::is_same_v){ - // Load SFB - auto sfb = static_cast(mainloop_params.SfB(n + n_b, k, l)); - b_frag[n_b] *= sfb; - } - - - if (mainloop_params.transform_B == ComplexTransform::kConjugate) { - b_frag[n_b] = conj(b_frag[n_b]); - } - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // do compute - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementSfD = typename EpilogueParams::TensorSFD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias = typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy; - - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool UseReLU = - cute::is_same_v>; // Treat Clamp as ReLU - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies mul; - plus add; - - // Activation operation - ActivationFunctor activation; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // vector alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // vector beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); - converted_beta = mul(converted_beta, converted_scale_c); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (UseReLU) { - cutlass::epilogue::thread::ReLU relu; - output = relu(output); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - - if constexpr ( - SfGenStrategy == SfStrategy::SfDGen - ) { - // 1d scale factor generation - constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{}; - if (epilogue_params.SfD.data() != nullptr) { - compute_1d_scaling_factor_and_quantized_output(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum); - } - } - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -auto make_layout_rank3(const TensorType& tensor) { - // append a batch mode of size 1 if we do not have tensors that are rank 3 - return make_layout( - make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}), - make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); -} - -/// GEMM - General Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - - if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { - cute::Layout layout_A = make_layout_rank3(mainloop_params.A); - cute::Layout layout_B = make_layout_rank3(mainloop_params.B); - cute::Layout layout_C = make_layout_rank3(epilogue_params.C); - cute::Layout layout_D = make_layout_rank3(epilogue_params.D); - cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); - cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); - cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); - cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); - - auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); - auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); - auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); - auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); - auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); - auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); - auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); - auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); - - // Reconstruct mainloop params - GettMainloopParams - mainloop_params_converted{TensorA, - TensorB, - mainloop_params.transform_A, - mainloop_params.transform_B}; - - // Reconstruct epilogue params - GettEpilogueParams - epilogue_params_converted{epilogue_params.alpha, - epilogue_params.beta, - TensorC, - TensorD, - VectorBias, - TensorAux, - VectorAlpha, - VectorBeta, - epilogue_params.abs_amax_D, - epilogue_params.abs_amax_Aux, - epilogue_params.scale_a, - epilogue_params.scale_b, - epilogue_params.scale_c, - epilogue_params.scale_d, - epilogue_params.scale_aux - }; - - Gett(mainloop_params_converted, epilogue_params_converted); - } - else { - // if we already have a batch mode, just pass it through - Gett(mainloop_params, epilogue_params); - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h deleted file mode 100644 index 67867533d5783b6e0047ac2110dc47adaa277e25..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h +++ /dev/null @@ -1,261 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for Rank 2k update in host-side code. - - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/mma.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/gemm.h" - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - FillMode FillModeC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_rank2k( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - static_assert( - FillModeC == FillMode::kLower || - FillModeC == FillMode::kUpper, - "Fill Mode can either be Lower or Upper."); - - using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), - std::greater_equal, - std::less_equal>::type; - - // Note: batch is ignored. - // Note: M is same as N for Rank 2k update - int const N = problem_size.n(); - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - CompareOp compare_op; - - for (int row_block = 0; row_block < N; row_block += Nblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Nblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Nblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Nblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < N && col < N && compare_op(row, col)) - { - - // A x B^T - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); - - ComputeType compute_a(cast_if_scalar(a)); - ComputeType compute_b_t(cast_if_scalar(b_t)); - - accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); - - // B x A^T - ElementB b = tensor_b.at(MatrixCoord(row, k_block)); - ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); - - ComputeType compute_b(cast_if_scalar(b)); - ComputeType compute_a_t(cast_if_scalar(a_t)); - - accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Nblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < N && col < N && - ( (FillModeC == FillMode::kLower && row >= col) || - (FillModeC == FillMode::kUpper && row <= col) ) - ) { - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * ScalarType(tensor_c.at(coord))); - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - FillMode FillModeC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_rank2k( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum) { - compute_rank2k( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, - initial_accum); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - FillMode FillModeC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = cutlass::arch::OpMultiplyAdd -> -struct Rank2K; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Rank2K { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_rank2k>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_rank2k>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h deleted file mode 100644 index a738101660f7ebbdd7c7796d46df244f1e3f5f70..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h +++ /dev/null @@ -1,318 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued Rank 2K update in host-side code. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Rank2KComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - FillMode fill_mode_c, - BlasMode blas_mode, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - // Rank2K update operates on A=NxK, B=NxK, and C=NxN - assert(M==N); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { - - // Compute matrix product using blocks - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N && - ( (fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col) ) - ) { - - // A x B^T (Symmetric) or A x B^H (Hermitian) - // complex conjugation on operandB (b_t) is function of blas3 computation - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b_t = (blas_mode == BlasMode::kHermitian) ? - conj(tensor_b.at(MatrixCoord(col, k_block))) : - tensor_b.at(MatrixCoord(col, k_block)); - - ComputeType a_ik = ComputeType(a); - ComputeType b_jk = ComputeType(b_t); - - // complex conjugation is a function of operand layouts - if (transform_a == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } - // complex conjugation is a function of operand layouts - if (transform_b == ComplexTransform::kConjugate) { - b_jk = conj(b_jk); - } - - accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); - } - } - } - } - - /* HER2K need two epilogues to handle complex alpha value */ - if ( blas_mode == BlasMode::kHermitian ) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N && - ((fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col)) - ) { - - ScalarType c = tensor_c.at(coord); - // The imaginary parts of the diagonal elements of - // a complex data type are assumed and set to zero - if (blas_mode == BlasMode::kHermitian) { - c = (row == col) ? real(c) : c; - } - - tensor_d.at(coord) = convert_op(alpha * - ScalarType(accum[i][j]) + - beta * c); - } - } - } - - /* Zeoring out accum for second HERK */ - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N && - ( (fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col) ) - ) { - - // B x A^T (Symmetric) or B x A^H (Hermitian) - // complex conjugation on operandB (a_t) is function of blas3 computation - ElementB b = tensor_b.at(MatrixCoord(row, k_block)); - ElementA a_t = (blas_mode == BlasMode::kHermitian) ? - conj(tensor_a.at(MatrixCoord(col, k_block))): - tensor_a.at(MatrixCoord(col, k_block)); - - ComputeType b_ik = ComputeType(b); - ComputeType a_jk = ComputeType(a_t); - - // complex conjugation here is a function of operand layouts - if (transform_b == ComplexTransform::kConjugate) { - b_ik = conj(b_ik); - } - // complex conjugation here is a function of operand layouts - if (transform_a == ComplexTransform::kConjugate) { - a_jk = conj(a_jk); - } - - accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); - } - } - } - } - - ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? - conj(alpha) : alpha; - ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? - 1 : beta; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N && - ((fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col)) - ) { - - ScalarType d = (blas_mode == BlasMode::kHermitian) ? - tensor_d.at(coord) : tensor_c.at(coord); - - ScalarType tmp_d = convert_op( - alpha_hermitian * ScalarType(accum[i][j]) + - beta_hermitian * d); - - if (blas_mode == BlasMode::kHermitian && row == col ) { - tensor_d.at(coord) = real(tmp_d); - } else { - tensor_d.at(coord) = tmp_d; - } - } - } - } - - } // for (col_block) - } // for (row_block) - - tensor_a.add_pointer_offset(batch_stride_A); - tensor_b.add_pointer_offset(batch_stride_B); - tensor_c.add_pointer_offset(batch_stride_C); - tensor_d.add_pointer_offset(batch_stride_D); - - } // for (batch_idx) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType -> -void Rank2KComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - TensorRef tensor_b, - ComplexTransform transform_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - FillMode fill_mode_c, - BlasMode blas_mode) { - - Rank2KComplex( - problem_size, alpha, - tensor_a, transform_a, - tensor_b, transform_b, - beta, tensor_c, tensor_d, - ScalarType(0), - fill_mode_c, - blas_mode); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h deleted file mode 100644 index 1aad33fd643b60752bc0845e403cebc43ad7d047..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h +++ /dev/null @@ -1,234 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued Rank 2K update in host-side code. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename ConvertOp = NumericConverter, - typename InnerProductOp = multiply_add -> -void Rank2KComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - FillMode fill_mode_c, - BlasMode blas_mode, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static_assert( - LayoutA::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - int const K = problem_size.k(); - - // Rank2K update operates on A=NxK, B=NxK, and C=NxN - assert(M==N); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - - for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { - - // Compute matrix product using blocks - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N && - ( (fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col) ) - ) { - - // A x A^T (Symmetric) or A x A^H (Hermitian) - // complex conjugation on operandB (a_t) (function of blas3 computation) - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementA a_t = (blas_mode == BlasMode::kHermitian) ? - conj(tensor_a.at(MatrixCoord(col, k_block))) : - tensor_a.at(MatrixCoord(col, k_block)); - - ComputeType a_ik = ComputeType(a); - ComputeType b_jk = ComputeType(a_t); - - // complex conjugation (function of input layouts) - if (transform_a == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } - // complex conjugation (function of input layouts) - if (transform_a == ComplexTransform::kConjugate) { - b_jk = conj(b_jk); - } - - accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); - - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N && - ((fill_mode_c == FillMode::kLower && row >= col) || - (fill_mode_c == FillMode::kUpper && row <= col)) - ) { - - ScalarType c = tensor_c.at(coord); - // The imaginary parts of the diagonal elements of - // a complex data type are assumed and set to zero - if (blas_mode == BlasMode::kHermitian) { - c = (row == col) ? real(c) : c; - } - - ScalarType tmp_d = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * c); - - if (blas_mode == BlasMode::kHermitian && row == col ) { - tensor_d.at(coord) = real(tmp_d); - } else { - tensor_d.at(coord) = tmp_d; - } - } - } - } - - } // for (col_block) - } // for (row_block) - - tensor_a.add_pointer_offset(batch_stride_A); - tensor_c.add_pointer_offset(batch_stride_C); - tensor_d.add_pointer_offset(batch_stride_D); - - } // for (batch_idx) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// This assumes the accumulator type is the same type as the scalars. -template < - typename ElementA, - typename LayoutA, - typename ElementC, - typename LayoutC, - typename ScalarType -> -void RankKComplex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - ComplexTransform transform_a, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - FillMode fill_mode_c, - BlasMode blas_mode) { - - Rank2KComplex( - problem_size, alpha, - tensor_a, transform_a, - beta, tensor_c, tensor_d, - ScalarType(0), - fill_mode_c, - blas_mode); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h deleted file mode 100644 index 34f9648f25f8965f6730999b7763220c360683a8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h +++ /dev/null @@ -1,285 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for SYMM update in host-side code. - - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/mma.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/gemm.h" - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_symm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - static_assert(SideModeA != SideMode::kInvalid - , "Side Mode can either be Left or Right."); - - static_assert( - FillModeA == FillMode::kLower || - FillModeA == FillMode::kUpper, - "Fill Mode can either be Lower or Upper."); - - using CompareOp_w_diag = typename TrMatrixCompareOp::Type; - using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - // Assuming correct k-dimension value is passed - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - CompareOp_w_diag compare_op_1; - CompareOp_wo_diag compare_op_2; - - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - ElementA a_1 = ElementA(); - ElementB b_1 = ElementB(); - ElementA a_2 = ElementA(); - ElementB b_2 = ElementB(); - - // A x B or B x A (with diagonal) - if (SideModeA == SideMode::kLeft) { - a_1 = (compare_op_1(row, k_block)) ? - (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); - b_1 = tensor_b.at(MatrixCoord(k_block, col)); - } else if (SideModeA == SideMode::kRight) { - a_1 = tensor_b.at(MatrixCoord(row, k_block)); - b_1 = (compare_op_1(k_block, col)) ? - tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); - } - - ComputeType compute_a_1(cast_if_scalar(a_1)); - ComputeType compute_b_1(cast_if_scalar(b_1)); - - accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); - - // A^T x B or B x A^T (without diagonal) - if (SideModeA == SideMode::kLeft) { - a_2 = (compare_op_2(k_block, row)) ? - (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); - b_2 = tensor_b.at(MatrixCoord(k_block, col)); - } else if (SideModeA == SideMode::kRight) { - a_2 = tensor_b.at(MatrixCoord(row, k_block)); - b_2 = (compare_op_2(col, k_block)) ? - tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); - } - - ComputeType compute_a_2(cast_if_scalar(a_2)); - ComputeType compute_b_2(cast_if_scalar(b_2)); - - accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * ScalarType(tensor_c.at(coord))); - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_symm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum) { - compute_symm( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, - initial_accum); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = cutlass::arch::OpMultiplyAdd -> -struct Symm; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Symm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_symm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); - } - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_symm>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h deleted file mode 100644 index 79e146f69b784a92ce61a093f410e93a66005cf8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h +++ /dev/null @@ -1,319 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued SYMM update in host-side code. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include - -namespace cutlass { -namespace reference { -namespace host { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -/// objects. -/// -/// Explicitly naming types needed by this template can be cumbersome, particularly for the -/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -/// AccumulatorType(0) as the last function argument can be easier than naming all template -/// arguments explicitly. -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - BlasMode BlasMode_ = BlasMode::kSymmetric, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_symm_complex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum, - int batch_count = 1, - int64_t batch_stride_A = 0, - int64_t batch_stride_B = 0, - int64_t batch_stride_C = 0, - int64_t batch_stride_D = 0) { - - static SideMode const kSideModeA = SideModeA; - static FillMode const kFillModeA = FillModeA; - static BlasMode const kBlasMode = BlasMode_; - - static_assert( - LayoutA::kRank == 2 && - LayoutB::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - static_assert(kSideModeA != SideMode::kInvalid - , "Side Mode can either be Left or Right."); - - static_assert( - kFillModeA == FillMode::kLower || - kFillModeA == FillMode::kUpper, - "Fill Mode can either be Lower or Upper."); - - using CompareOp_w_diag = typename TrMatrixCompareOp::Type; - using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - // Assuming correct k-dimension value is passed - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - CompareOp_w_diag compare_op_1; - CompareOp_wo_diag compare_op_2; - - for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { - - // Compute matrix product using blocks - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) - { - ElementA a_1 = ElementA(); - ElementB b_1 = ElementB(); - ElementA a_2 = ElementA(); - ElementB b_2 = ElementB(); - - // A x B or B x A (with diagonal) - if (kSideModeA == SideMode::kLeft) { - a_1 = (compare_op_1(row, k_block)) ? - (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); - b_1 = tensor_b.at(MatrixCoord(k_block, col)); - } else if (kSideModeA == SideMode::kRight) { - a_1 = tensor_b.at(MatrixCoord(row, k_block)); - b_1 = (compare_op_1(k_block, col)) ? - tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); - } - ComputeType compute_a_1 = ComputeType(a_1); - ComputeType compute_b_1 = ComputeType(b_1); - - // The imaginary parts of the diagonal elements of - // a complex data type are assumed and set to zero - if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { - compute_a_1 = real(compute_a_1); - } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { - compute_b_1 = real(compute_b_1); - } - - accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); - - // A^T x B or B x A^T (without diagonal) - if (kSideModeA == SideMode::kLeft) { - a_2 = (compare_op_2(k_block, row)) ? - (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); - b_2 = tensor_b.at(MatrixCoord(k_block, col)); - if (kBlasMode == BlasMode::kHermitian) - a_2 = conj(a_2); - } else if (kSideModeA == SideMode::kRight) { - a_2 = tensor_b.at(MatrixCoord(row, k_block)); - b_2 = (compare_op_2(col, k_block)) ? - tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); - if (kBlasMode == BlasMode::kHermitian) - b_2 = conj(b_2); - } - - ComputeType compute_a_2 = ComputeType(a_2); - ComputeType compute_b_2 = ComputeType(b_2); - - accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - - ScalarType c = tensor_c.at(coord); - - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * c); - } - } - } - - } // for (col_block) - } // for (row_block) - - tensor_a.add_pointer_offset(batch_stride_A); - tensor_b.add_pointer_offset(batch_stride_B); - tensor_c.add_pointer_offset(batch_stride_C); - tensor_d.add_pointer_offset(batch_stride_D); - - } // for (batch_idx) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, - typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex -> -struct SymmComplex; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct SymmComplex { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_symm_complex>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for gaussian multiply-add -template -struct SymmComplex { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, ScalarType beta, - TensorRef tensor_c, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_symm_complex>( - problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h deleted file mode 100644 index d6b85ca1baf65ba811b7c8b3a224ca90bbce1680..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h +++ /dev/null @@ -1,616 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines host-side elementwise operations on TensorView. -*/ - -#pragma once - -// Standard Library includes -#include - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/relatively_equal.h" -#include "cutlass/tensor_view.h" -#include "cutlass/tensor_view_planar_complex.h" - -#include "cutlass/util/distribution.h" -#include "tensor_foreach.h" - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorGreatestErrorFunc { - - // - // Data members - // - - TensorView lhs; - TensorView rhs; - double result; - - /// Ctor - TensorGreatestErrorFunc( - TensorView const &lhs_, - TensorView const &rhs_ - ) : - lhs(lhs_), - rhs(rhs_), - result(0.0) { } - - /// Visits a coordinate - void operator()(Coord const &coord) { - - Element lhs_ = lhs.at(coord); - Element rhs_ = rhs.at(coord); - - result = std::max(result, std::abs(double(lhs_) - double(rhs_))); - } - - /// Returns true if equal - operator double() const { - return result; - } -}; - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorMREFunc { - - // - // Data members - // - - TensorView lhs; - TensorView rhs; - double sum; - uint64_t count; - static constexpr double epsilon = 1e-6; - - /// Ctor - TensorMREFunc( - TensorView const &lhs_, - TensorView const &rhs_ - ) : - lhs(lhs_), - rhs(rhs_), - sum(0.0), - count(0) { } - - /// Visits a coordinate - void operator()(Coord const &coord) { - - Element lhs_ = lhs.at(coord); - Element rhs_ = rhs.at(coord); - - sum += std::abs(double(lhs_) - double(rhs_) / (double(rhs_) + epsilon)); - ++count; - } - - /// Returns true if equal - operator double() const { - return sum / double(count); - } -}; - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorMSEFunc { - - // - // Data members - // - - TensorView lhs; - TensorView rhs; - double sum; - uint64_t count; - - /// Ctor - TensorMSEFunc( - TensorView const &lhs_, - TensorView const &rhs_ - ) : - lhs(lhs_), - rhs(rhs_), - sum(0.0), - count(0) { } - - /// Visits a coordinate - void operator()(Coord const &coord) { - - Element lhs_ = lhs.at(coord); - Element rhs_ = rhs.at(coord); - - sum += std::pow((double(lhs_) - double(rhs_)), 2); - ++count; - } - - /// Returns true if equal - operator double() const { - return sum / double(count); - } -}; - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorEqualsFunc { - - // - // Data members - // - - TensorView lhs; - TensorView rhs; - bool result; - - /// Ctor - TensorEqualsFunc(): result(true) { } - - /// Ctor - TensorEqualsFunc( - TensorView const &lhs_, - TensorView const &rhs_ - ) : - lhs(lhs_), rhs(rhs_), result(true) { } - - /// Visits a coordinate - void operator()(Coord const &coord) { - - Element lhs_ = lhs.at(coord); - Element rhs_ = rhs.at(coord); - - if (lhs_ != rhs_) { - result = false; - } - } - - /// Returns true if equal - operator bool() const { - return result; - } -}; - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorRelativelyEqualsFunc { - - // - // Data members - // - - TensorView lhs; - TensorView rhs; - Element epsilon; - Element nonzero_floor; - bool result; - - /// Ctor - TensorRelativelyEqualsFunc( - TensorView const &lhs_, - TensorView const &rhs_, - Element epsilon_, - Element nonzero_floor_ - ) : - lhs(lhs_), - rhs(rhs_), - epsilon(epsilon_), - nonzero_floor(nonzero_floor_), - result(true) { } - - /// Visits a coordinate - void operator()(Coord const &coord) { - - Element lhs_ = lhs.at(coord); - Element rhs_ = rhs.at(coord); - - if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) { - result = false; - } - } - - /// Returns true if equal - operator bool() const { - return result; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns the Mean Squared Error between two tensors. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -double TensorMSE( - TensorView const &lhs, - TensorView const &rhs) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return -1; - } - - detail::TensorMSEFunc func(lhs, rhs); - TensorForEach( - lhs.extent(), - func - ); - - return double(func); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns the Mean Relative Error between two tensors. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -double TensorMRE( - TensorView const &lhs, - TensorView const &rhs) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return -1; - } - - detail::TensorMREFunc func(lhs, rhs); - TensorForEach( - lhs.extent(), - func - ); - - return double(func); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns the greatest error between two tensors. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -double TensorGreatestError( - TensorView const &lhs, - TensorView const &rhs) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return -1; - } - - detail::TensorGreatestErrorFunc func(lhs, rhs); - TensorForEach( - lhs.extent(), - func - ); - - return double(func); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns true if two tensor views are equal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorEquals( - TensorView const &lhs, - TensorView const &rhs) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return false; - } - - detail::TensorEqualsFunc func(lhs, rhs); - TensorForEach( - lhs.extent(), - func - ); - - return bool(func); -} - -/// Returns true if two tensor views are equal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorEquals( - TensorViewPlanarComplex const &lhs, - TensorViewPlanarComplex const &rhs) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return false; - } - - detail::TensorEqualsFunc real_func( - {lhs.data(), lhs.layout(), lhs.extent()}, - {rhs.data(), rhs.layout(), rhs.extent()} - ); - - TensorForEach( - lhs.extent(), - real_func - ); - - if (!bool(real_func)) { - return false; - } - - detail::TensorEqualsFunc imag_func( - {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, - {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} - ); - - TensorForEach( - lhs.extent(), - imag_func - ); - - return bool(imag_func); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns true if two tensor views are relatively equal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorRelativelyEquals( - TensorView const &lhs, - TensorView const &rhs, - Element epsilon, - Element nonzero_floor) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return false; - } - - detail::TensorRelativelyEqualsFunc func(lhs, rhs, epsilon, nonzero_floor); - TensorForEach( - lhs.extent(), - func - ); - - return bool(func); -} - -/// Returns true if two tensor views are relatively equal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorRelativelyEquals( - TensorViewPlanarComplex const &lhs, - TensorViewPlanarComplex const &rhs, - Element epsilon, - Element nonzero_floor) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return false; - } - - detail::TensorRelativelyEqualsFunc real_func( - {lhs.data(), lhs.layout(), lhs.extent()}, - {rhs.data(), rhs.layout(), rhs.extent()}, - epsilon, - nonzero_floor - ); - - TensorForEach( - lhs.extent(), - real_func - ); - - if (!bool(real_func)) { - return false; - } - - detail::TensorEqualsFunc imag_func( - {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, - {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}, - epsilon, - nonzero_floor - ); - - TensorForEach( - lhs.extent(), - imag_func - ); - - return bool(imag_func); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns true if two tensor views are NOT equal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorNotEquals( - TensorView const &lhs, - TensorView const &rhs) { - - // Extents must be identical - if (lhs.extent() != rhs.extent()) { - return true; - } - - detail::TensorEqualsFunc func(lhs, rhs); - TensorForEach( - lhs.extent(), - func - ); - - return !bool(func); -} - -/// Returns true if two tensor views are equal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorNotEquals( - TensorViewPlanarComplex const &lhs, - TensorViewPlanarComplex const &rhs) { - - return !TensorEquals(lhs, rhs); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorContainsFunc { - - // - // Data members - // - - TensorView view; - Element value; - bool contains; - Coord location; - - // - // Methods - // - - /// Ctor - TensorContainsFunc(): contains(false) { } - - /// Ctor - TensorContainsFunc( - TensorView const &view_, - Element value_ - ) : - view(view_), value(value_), contains(false) { } - - /// Visits a coordinate - void operator()(Coord const &coord) { - - if (view.at(coord) == value) { - if (!contains) { - location = coord; - } - contains = true; - } - } - - /// Returns true if equal - operator bool() const { - return contains; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns true if a value is present in a tensor -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -bool TensorContains( - TensorView const & view, - Element value) { - - detail::TensorContainsFunc func( - view, - value - ); - - TensorForEach( - view.extent(), - func - ); - - return bool(func); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of -/// of the first occurrence. If the value is not contained in the tensor, the second element of the -/// pair is undefined. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -std::pair > TensorFind( - TensorView const & view, - Element value) { - - detail::TensorContainsFunc func( - view, - value - ); - - TensorForEach( - view.extent(), - func - ); - - return std::make_pair(bool(func), func.location); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp deleted file mode 100644 index 27ef969b4ff2b6d8f3a53f3d1a3e5ec3e5203ec3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp +++ /dev/null @@ -1,101 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Provides several functions for filling tensors with data. -*/ - -#pragma once - -// Standard Library includes -#include -#include -#include - -// Cute includes -#include "cute/tensor.hpp" - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/quaternion.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Returns true if two tensor views are equal. -template < - typename TensorL, - typename TensorR -> -bool TensorEquals( - TensorL lhs, - TensorR rhs) { - - // Extents must be identical - if (cute::size(lhs) != cute::size(rhs)) { - return false; - } - - for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { - if (lhs(idx) != rhs(idx)) { - return false; - } - } - - return true; -} - -/// Returns true if two tensor views are NOT equal. -template < - typename TensorL, - typename TensorR -> -bool TensorNotEquals( - TensorL lhs, - TensorR rhs) { - - return TensorEquals(lhs, rhs); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h deleted file mode 100644 index d2a43b1295c8ab18c7d649c79b0364b6d3e7c48c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h +++ /dev/null @@ -1,256 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines host-side elementwise operations on TensorView. -*/ - -#pragma once - -// Standard Library includes -#include - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "tensor_foreach.h" - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Helper to convert between types -template < - typename DstElement, - typename SrcElement -> -struct TrivialConvert { - - TrivialConvert() { } - - DstElement operator()(SrcElement src) const { - return DstElement(src); - } -}; - -/// Helper to conditionally copy between tensor views. -template < - typename DstElement, - typename DstLayout, - typename SrcElement, - typename SrcLayout, - typename F -> -struct TensorCopyIf { - - using DstTensorView = TensorView; - using SrcTensorView = TensorView; - - // - // Data members - // - - DstTensorView dst; - SrcTensorView src; - F convert; - - // - // Methods - // - - TensorCopyIf() { } - - TensorCopyIf( - DstTensorView const &dst_, - SrcTensorView const &src_, - F const &convert_): dst(dst_), src(src_), convert(convert_) {} - - /// Copies based on destination and source bounds - void operator()(Coord const &coord) { - if (dst.contains(coord) && src.contains(coord)) { - dst.at(coord) = convert(src.at(coord)); - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies elements from one tensor view into another, satisfying bounds of each tensor. -template < - typename DstElement, /// Destination tensor's element type - typename DstLayout, /// Destination tensor's layout - typename SrcElement, /// Source tensor's element type - typename SrcLayout, /// Source tensor's layout - typename F /// Transformation functor -> -void TensorCopy( - TensorView dst, - TensorView src, - F const &transform) { - - using CopyIf = detail::TensorCopyIf< - DstElement, - DstLayout, - SrcElement, - SrcLayout, - F>; - - CopyIf copy_if(dst, src, transform); - - TensorForEach(dst.extent(), copy_if); -} - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent -/// to avoid out of bounds accesses. -template < - typename DstElement, /// Destination tensor's element type - typename DstLayout, /// Destination tensor's layout - typename SrcElement, /// Source tensor's element type - typename SrcLayout, /// Source tensor's layout - typename F /// Transformation functor -> -void TensorCopy( - TensorView dst, - TensorRef src, - F const &transform) { - - using CopyIf = detail::TensorCopyIf< - DstElement, - DstLayout, - SrcElement, - SrcLayout, - F>; - - TensorView src_view(src, dst.extent()); - - CopyIf copy_if(dst, src_view, transform); - - TensorForEach(dst.extent(), copy_if); -} - -/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent -/// to avoid out of bounds accesses. -template < - typename DstElement, /// Destination tensor's element type - typename DstLayout, /// Destination tensor's layout - typename SrcElement, /// Source tensor's element type - typename SrcLayout, /// Source tensor's layout - typename F /// Transformation functor -> -void TensorCopy( - TensorRef dst, - TensorView src, - F const &transform) { - - using CopyIf = detail::TensorCopyIf< - DstElement, - DstLayout, - SrcElement, - SrcLayout, - F>; - - TensorView dst_view(dst, src.extent()); - - CopyIf copy_if(dst_view, src, transform); - - TensorForEach(src.extent(), copy_if); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -/// if SrcElement can be converted to DstElement. -template < - typename DstElement, /// Destination tensor's element type - typename DstLayout, /// Destination tensor's layout - typename SrcElement, /// Source tensor's element type - typename SrcLayout /// Source tensor's layout -> -void TensorCopy( - TensorView dst, - TensorView src) { - - detail::TrivialConvert convert; - - TensorCopy(dst, src, convert); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -/// if SrcElement can be converted to DstElement. -template < - typename DstElement, /// Destination tensor's element type - typename DstLayout, /// Destination tensor's layout - typename SrcElement, /// Source tensor's element type - typename SrcLayout, /// Source tensor's layout - typename F /// Transformation functor -> -void TensorCopy( - TensorView dst, - TensorRef src) { - - detail::TrivialConvert convert; - - TensorCopy(dst, src, convert); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -/// if SrcElement can be converted to DstElement. -template < - typename DstElement, /// Destination tensor's element type - typename DstLayout, /// Destination tensor's layout - typename SrcElement, /// Source tensor's element type - typename SrcLayout /// Source tensor's layout -> -void TensorCopy( - TensorRef dst, - TensorView src) { - - detail::TrivialConvert convert; - - TensorCopy(dst, src, convert); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h deleted file mode 100644 index 5470df29358799f6d5e6628e8722f0e3dc05485f..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h +++ /dev/null @@ -1,341 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Defines host-side elementwise operations on TensorView. -*/ - -#pragma once - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/functional.h" - -#include "tensor_foreach.h" - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to apply a binary operator in place -template < - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB, - typename ElementD, - typename LayoutD, - typename BinaryFunc> -struct TensorFuncBinaryOp { - - // - // Data members - // - - /// View of left-hand-side tensor - TensorView view_d; - TensorRef view_a; - TensorRef view_b; - BinaryFunc func; - - // - // Methods - // - - /// Constructor - TensorFuncBinaryOp() { } - - /// Constructor - TensorFuncBinaryOp( - TensorView const & view_d_, - TensorRef const & view_a_, - TensorRef const & view_b_, - BinaryFunc func = BinaryFunc() - ): - view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } - - /// Equality check - void operator()(Coord const &coord) const { - view_d.at(coord) = func( - ElementD(view_a.at(coord)), - ElementD(view_b.at(coord)) - ); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Adds two tensors and stores in the destination tensor: d = a + b -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB -> -void TensorAdd( - TensorView d, ///< destination tensor view - TensorRef a, ///< A tensor reference - TensorRef b ///< B tensor reference -) { - - detail::TensorFuncBinaryOp< - ElementD, - LayoutD, - ElementA, - LayoutA, - ElementB, - LayoutB, - cutlass::plus - > func(d, a, b); - - TensorForEach( - d.extent(), - func); -} - -/// Adds a tensor in place: d = d .+ a -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA -> -void TensorAdd( - TensorView d, ///< destination tensor view - TensorRef a ///< A tensor reference -) { - TensorAdd(d, d, a); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Subtracts two tensors and stores in the destination tensor: d = a - b -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB -> -void TensorSub( - TensorView d, ///< destination tensor view - TensorRef a, ///< A tensor reference - TensorRef b ///< B tensor reference - ) { - - detail::TensorFuncBinaryOp< - ElementD, - LayoutD, - ElementA, - LayoutA, - ElementB, - LayoutB, - cutlass::minus - > func(d, a, b); - - TensorForEach( - d.extent(), - func); -} - -/// Subtracts two tensors in place: d = d .- a -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB -> -void TensorSub( - TensorView d, ///< destination tensor view - TensorRef a ///< A tensor reference - ) { - - TensorSub(d, d, a); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Multiplies two tensors and stores in the destination tensor: d = a .* b -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB -> -void TensorMul( - TensorView d, ///< destination tensor view - TensorRef a, ///< A tensor reference - TensorRef b ///< B tensor reference -) { - - detail::TensorFuncBinaryOp< - ElementD, - LayoutD, - ElementA, - LayoutA, - ElementB, - LayoutB, - cutlass::multiplies - > func(d, a, b); - - TensorForEach( - d.extent(), - func); -} - -/// Multiplies tensors in place: d = d .* a -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA -> -void TensorMul( - TensorView d, ///< destination tensor view - TensorRef a ///< A tensor reference -) { - TensorMul(d, d, a); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Divides two tensors and stores in the destination tensor: d = a ./ b -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB -> -void TensorDiv( - TensorView d, ///< destination tensor view - TensorRef a, ///< A tensor reference - TensorRef b ///< B tensor reference -) { - - detail::TensorFuncBinaryOp< - ElementD, - LayoutD, - ElementA, - LayoutA, - ElementB, - LayoutB, - cutlass::divides - > func(d, a, b); - - TensorForEach( - d.extent(), - func); -} - -/// Divides tensors in place: d = d ./ a -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA -> -void TensorDiv( - TensorView d, ///< destination tensor view - TensorRef a ///< A tensor reference -) { - TensorDiv(d, d, a); -} - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Divides two tensors and stores in the destination tensor: d = a ./ b -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA, - typename ElementB, - typename LayoutB -> -void TensorModulus( - TensorView d, ///< destination tensor view - TensorRef a, ///< A tensor reference - TensorRef b ///< B tensor reference -) { - - detail::TensorFuncBinaryOp< - ElementD, - LayoutD, - ElementA, - LayoutA, - ElementB, - LayoutB, - cutlass::divides - > func(d, a, b); - - TensorForEach( - d.extent(), - func); -} - -/// Divides tensors in place: d = d ./ a -template < - typename ElementD, - typename LayoutD, - typename ElementA, - typename LayoutA -> -void TensorModulus( - TensorView d, ///< destination tensor view - TensorRef a ///< A tensor reference -) { - TensorDiv(d, d, a); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h deleted file mode 100644 index 645902f7dd7b62bc98a479e4956dfb4b437d46a7..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ /dev/null @@ -1,1718 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Provides several functions for filling tensors with data. -*/ - -#pragma once - -// Standard Library includes -#include -#include -#include -#include -#include - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/quaternion.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/subbyte_reference.h" -#include "cutlass/tensor_view.h" -#include "cutlass/tensor_view_planar_complex.h" -#include "cutlass/blas3.h" - -#include "cutlass/util/distribution.h" -#include "tensor_foreach.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - Element value; - - // - // Methods - // - - TensorFillFunc( - TensorView const &view_ = TensorView(), - Element value_ = Element(0) - ): view(view_), value(value_) { } - - void operator()(Coord const & coord) const { - view.at(coord) = value; - } -}; - -/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method -struct BoxMullerFunc { - - BoxMullerFunc() {} - - void operator()( - double* rnd, ///< Size-2 vector to be filled with random values - double mean = 0, ///< Mean of the Gaussian distribution - double stddev = 1, ///< Standard deviation of the Gaussian distribution - double pi = std::acos(-1)) const { - - double u1 = double(std::rand()) / double(RAND_MAX); - double u2 = double(std::rand()) / double(RAND_MAX); - rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); - rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); - rnd[0] = mean + stddev * rnd[0]; - rnd[1] = mean + stddev * rnd[1]; - } -}; -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with a uniform value -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFill( - TensorView dst, ///< destination tensor - Element val = Element(0)) { ///< value to uniformly fill it with - - detail::TensorFillFunc func(dst, val); - - TensorForEach( - dst.extent(), - func - ); -} - -/// Fills a tensor with a uniform value -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFill( - TensorViewPlanarComplex dst, ///< destination tensor - cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with - - TensorFill(dst.view_real(), val.real()); - TensorFill(dst.view_imag(), val.imag()); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct RandomGaussianFunc { - - uint64_t seed; - double mean; - double stddev; - int int_scale; - double pi; - double pnz; - bool exclude_zero; - - // - // Methods - // - RandomGaussianFunc( - uint64_t seed_ = 0, - double mean_ = 0, - double stddev_ = 1, - int int_scale_ = -1, - double pnz_ = 1.0, - bool exclude_zero_ = false - ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { - std::srand((unsigned)seed); - } - - /// Compute random value and update RNG state - Element operator()() const { - - // Box-Muller transform to generate random numbers with Normal distribution - double u1 = double(std::rand()) / double(RAND_MAX); - double u2 = double(std::rand()) / double(RAND_MAX); - - // Compute Gaussian random value - double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); - rnd = mean + stddev * rnd; - - // Scale and convert final result - Element result; - - // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian - std::random_device rnd_device; - std::mt19937 bernoulli_rnd(rnd_device()); - std::bernoulli_distribution bernoulli_dist(pnz); - bool bernoulli_result = bernoulli_dist(bernoulli_rnd); - - // Sample from the Gaussian distribution for a nonzero element - if (bernoulli_result) { - if (int_scale >= 0) { - rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); - result = static_cast(rnd); - } - else { - result = static_cast(rnd); - } - } - else { - result = static_cast(0); - } - - // Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros - if (exclude_zero && result == Element(0)) { - if (rnd > 0) { - rnd += 1; - } else { - rnd -= 1; - } - result = Element(rnd); - } - - return result; - } -}; - -/// Partial specialization for initializing a complex value. -template -struct RandomGaussianFunc > { - - uint64_t seed; - double mean; - double stddev; - int int_scale; - double pi; - double pnz; - bool exclude_zero; - - // - // Methods - // - RandomGaussianFunc( - uint64_t seed_ = 0, - double mean_ = 0, - double stddev_ = 1, - int int_scale_ = -1, - double pnz_ = 1.0, - bool exclude_zero_ = false - ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { - std::srand((unsigned)seed); - } - - /// Compute random value and update RNG state - complex operator()() const { - - Element reals[2]; - - double rnd[2]; - detail::BoxMullerFunc func; - func(rnd, mean, stddev, pi); - - // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian - std::random_device rnd_device; - std::mt19937 bernoulli_rnd(rnd_device()); - std::bernoulli_distribution bernoulli_dist(pnz); - bool bernoulli_result = bernoulli_dist(bernoulli_rnd); - - // Sample from the Gaussian distribution for a nonzero element - if (bernoulli_result) { - if (int_scale >= 0) { - rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale))); - rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale))); - reals[0] = from_real(rnd[0] / double(1 << int_scale)); - reals[1] = from_real(rnd[1] / double(1 << int_scale)); - } - else { - reals[0] = from_real(rnd[0]); - reals[1] = from_real(rnd[1]); - } - } - else { - reals[0] = from_real(0); - reals[1] = from_real(0); - } - - // Note that this will invalidate the above else statement because it unsets zero elements - if (exclude_zero && - reals[0] == from_real(0.0) && - reals[1] == from_real(0.0)) { - - if (rnd[0] > 0.0) { - rnd[0] += 1.0; - } else { - rnd[0] -= 1.0; - } - reals[0] = from_real(rnd[0]); - } - - return complex(reals[0], reals[1]); - } -}; - -/// Partial specialization for initializing a complex value. -template -struct RandomGaussianFunc > { - - uint64_t seed; - double mean; - double stddev; - int int_scale; - double pi; - double pnz; - bool exclude_zero; - - // - // Methods - // - RandomGaussianFunc( - uint64_t seed_ = 0, - double mean_ = 0, - double stddev_ = 1, - int int_scale_ = -1, - double pnz_ = 1.0, - bool exclude_zero_ = false - ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { - std::srand((unsigned)seed); - } - - /// Compute random value and update RNG state - Quaternion operator()() const { - - Element reals[4]; - - double rnd1[2]; - double rnd2[2]; - detail::BoxMullerFunc func; - func(rnd1, mean, stddev, pi); - func(rnd2, mean, stddev, pi); - - // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian - std::random_device rnd_device; - std::mt19937 bernoulli_rnd(rnd_device()); - std::bernoulli_distribution bernoulli_dist(pnz); - bool bernoulli_result = bernoulli_dist(bernoulli_rnd); - - // Sample from the Gaussian distribution for a nonzero element - if (bernoulli_result) { - if (int_scale >= 0) { - rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale))); - rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale))); - rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale))); - rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale))); - - reals[0] = from_real(rnd1[0] / double(1 << int_scale)); - reals[1] = from_real(rnd1[1] / double(1 << int_scale)); - reals[2] = from_real(rnd2[0] / double(1 << int_scale)); - reals[3] = from_real(rnd2[1] / double(1 << int_scale)); - } - else { - reals[0] = from_real(rnd1[0]); - reals[1] = from_real(rnd1[1]); - reals[2] = from_real(rnd2[0]); - reals[3] = from_real(rnd2[1]); - } - } - else { - reals[0] = from_real(0); - reals[1] = from_real(0); - reals[2] = from_real(0); - reals[3] = from_real(0); - } - - // Note that this will invalidate the above else statement because it unsets zero elements - if (exclude_zero && - reals[0] == from_real(0) && - reals[1] == from_real(0) && - reals[2] == from_real(0) && - reals[3] == from_real(0)) { - - if (rnd1[0] > 0.0) { - rnd1[0] += 1.0; - } else { - rnd1[0] -= 1.0; - } - reals[0] = from_real(rnd1[0]); - } - - return Quaternion(reals[0], reals[1], reals[2], reals[3]); - } -}; - -/// Computes a random Gaussian distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillGaussianFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - RandomGaussianFunc func; - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - TensorFillGaussianFunc( - TensorView view_ = TensorView(), - RandomGaussianFunc func_ = RandomGaussianFunc() - ): - view(view_), func(func_) { - - } - - /// Compute random value and update RNG state - void operator()(Coord const &coord) const { - view.at(coord) = func(); - } -}; - -/// Computes a random Gaussian distribution for a rank-2 tensor -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillSymmetricGaussianFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - RandomGaussianFunc func; - cutlass::FillMode fill_mode; - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - TensorFillSymmetricGaussianFunc( - TensorView view_ = TensorView(), - RandomGaussianFunc func_ = RandomGaussianFunc(), - cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid - ): - view(view_), func(func_), fill_mode(fill_mode_) { - - } - - /// Compute random value and update RNG state - void operator()(Coord const &coord) const { - // Fill half of matrix based on FillMode - if (Layout::kRank == 2 && - fill_mode == cutlass::FillMode::kLower && - coord[0] >= coord[1]) { - view.at(coord) = func(); - } else if (Layout::kRank == 2 && - fill_mode == cutlass::FillMode::kUpper && - coord[0] <= coord[1]) { - view.at(coord) = func(); - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a Gaussian distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomGaussian( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1, ///< If non-negative, specifies number of fractional bits that - double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of - /// data. - bool exclude_zero = false) { ///< Exclude zeros from tensor init. - - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz, exclude_zero); - - detail::TensorFillGaussianFunc func( - dst, - random_func - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/// Fills a tensor with random values with a Gaussian distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomGaussian( - TensorViewPlanarComplex dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1, ///< If non-negative, specifies number of fractional bits that - double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of - /// data. - bool exclude_zero = false) { ///< Exclude zeros from tensor init. - - TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); - TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillSymmetricRandomGaussian( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1, ///< If non-negative, specifies number of fractional bits that - double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of - /// data. - - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); - - detail::TensorFillSymmetricGaussianFunc func( - dst, - random_func, - fill_mode - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values of a Gaussian distribution. -template < - typename Element ///< Element type -> -void BlockFillRandomGaussian( - Element *ptr, ///< destination buffer - size_t capacity, ///< number of elements - uint64_t seed, ///< seed for RNG - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1, ///< If non-negative, specifies number of fractional bits that - double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of - /// data. - - - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); - - for (size_t i = 0; i < capacity; ++i) { - ReferenceFactory::get(ptr, i) = random_func(); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct RandomUniformFunc { - - using Real = typename RealType::Type; - - uint64_t seed; - double range; - double min; - int int_scale; - - double pnan; -private: - using engine_type = std::mt19937; -public: - engine_type bernoulli_rnd; - std::bernoulli_distribution bernoulli_dist; - - bool exclude_zero; - - RandomUniformFunc( - uint64_t seed_ = 0, - double max = 1, - double min_ = 0, - int int_scale_ = -1, - double pnan_ = 0, - bool exclude_zero_ = false - ): - seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) - , bernoulli_rnd{static_cast(seed_)} - , bernoulli_dist(pnan_) - , exclude_zero(exclude_zero_) - { - std::srand((unsigned)seed); - - // Handle cases where min = 0 or max = 0 for excluding zeros - if (exclude_zero) { - min = (min == 0.0) ? min + 1: min; - range = (max == 0.0) ? range - 1: range; - } - } - - - /// Compute random value and update RNG state - Element operator()() { - - // Sample from NaN distribution. - if constexpr (std::numeric_limits::has_quiet_NaN) { - if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { - return Element(NAN); - } - } - - double rnd = double(std::rand()) / double(RAND_MAX); - - rnd = min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - Element result; - if (int_scale >= 0) { - rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); - result = static_cast(Real(rnd)); - } - else { - result = static_cast(Real(rnd)); - } - - if (exclude_zero && result == Element(0)) { - if (rnd > 0.0) { - rnd = std::min(min + range, rnd + 1.0); - } else { - rnd = std::max(min, rnd - 1.0); - } - result = static_cast(Real(rnd)); - } - - return result; - } -}; - -/// Partial specialization for initializing a complex value. -template -struct RandomUniformFunc > { - - using Real = typename RealType::Type; - - uint64_t seed; - double range; - double min; - int int_scale; - - double pnan; -private: - using engine_type = std::mt19937; -public: - engine_type bernoulli_rnd; - std::bernoulli_distribution bernoulli_dist; - - bool exclude_zero; - - // - // Methods - // - - RandomUniformFunc( - uint64_t seed_ = 0, - double max = 1, - double min_ = 0, - int int_scale_ = -1, - double pnan_ = 0, - bool exclude_zero_ = false - ): - seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) - , bernoulli_rnd{static_cast(seed_)} - , bernoulli_dist(pnan_) - , exclude_zero(exclude_zero_) { - std::srand((unsigned)seed); - - // Handle cases where min = 0 or max = 0 for excluding zeros - if (exclude_zero) { - min = (min == 0.0) ? min + 1: min; - range = (max == 0.0) ? range - 1: range; - } - } - - - /// Compute random value and update RNG state - complex operator()() { - - // Sample from NaN distribution. - if constexpr (std::numeric_limits::has_quiet_NaN) { - if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { - return Element(NAN); - } - } - - Element reals[2]; - - for (int i = 0; i < 2; ++i) { - double rnd = double(std::rand()) / double(RAND_MAX); - - rnd = min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - - if (int_scale >= 0) { - rnd = double(std::llround(rnd * double(1 << int_scale))); - reals[i] = from_real(Real(rnd / double(1 << int_scale))); - } - else { - reals[i] = from_real(Real(rnd)); - } - - if (exclude_zero && - i == 0 && - reals[0] == from_real(0.0)) { - - if (rnd > 0.0) { - rnd = std::min(min + range, rnd + 1.0); - } else { - rnd = std::max(min, rnd - 1.0); - } - reals[0] = from_real(Real(rnd)); - } - - } - - return complex(reals[0], reals[1]); - } -}; - -/// Partial specialization for initializing a Quaternion value. -template -struct RandomUniformFunc > { - - using Real = typename RealType::Type; - - uint64_t seed; - double range; - double min; - int int_scale; - - double pnan; -private: - using engine_type = std::mt19937; -public: - engine_type bernoulli_rnd; - std::bernoulli_distribution bernoulli_dist; - - // - // Methods - // - - RandomUniformFunc( - uint64_t seed_ = 0, - double max = 1, - double min_ = 0, - int int_scale_ = -1, - double pnan_ = 0 - ): - seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_), - bernoulli_rnd{static_cast(seed_)}, - bernoulli_dist(pnan_) - { - std::srand((unsigned)seed); - } - - - /// Compute random value and update RNG state - Quaternion operator()() { - - // Sample from NaN distribution. - if constexpr (std::numeric_limits::has_quiet_NaN) { - if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { - return Element(NAN); - } - } - - Element reals[4]; - - for (int i = 0; i < 4; ++i) { - double rnd = double(std::rand()) / double(RAND_MAX); - - rnd = min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - - if (int_scale >= 0) { - rnd = double(std::llround(rnd * double(1 << int_scale))); - reals[i] = from_real(Real(rnd / double(1 << int_scale))); - } - else { - reals[i] = from_real(Real(rnd)); - } - } - - return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); - } -}; - -/// Computes a random uniform distribution -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillRandomUniformFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - RandomUniformFunc func; - - // - // Methods - // - - /// Construction of uniform RNG functor. - TensorFillRandomUniformFunc( - TensorView view_ = TensorView(), - RandomUniformFunc func_ = RandomUniformFunc() - ): - view(view_), func(func_) { - - } - - /// Compute random value and update RNG state - void operator()(Coord const &coord) { - - view.at(coord) = func(); - } -}; - -/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillSymmetricRandomUniformFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - RandomUniformFunc func; - cutlass::FillMode fill_mode; - - // - // Methods - // - - /// Construction of uniform RNG functor. - TensorFillSymmetricRandomUniformFunc( - TensorView view_ = TensorView(), - RandomUniformFunc func_ = RandomUniformFunc(), - cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid - ): - view(view_), func(func_), fill_mode(fill_mode_) { - - } - - /// Compute random value and update RNG state - void operator()(Coord const &coord) { - // Fill half of matrix based on FillMode - if (Layout::kRank == 2 && - fill_mode == cutlass::FillMode::kLower && - coord[0] >= coord[1]) { - view.at(coord) = func(); - } else if (Layout::kRank == 2 && - fill_mode == cutlass::FillMode::kUpper && - coord[0] <= coord[1]) { - view.at(coord) = func(); - } - } -}; - -/// Computes a random Uniform distribution and pads diagonal with zeros -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillPadDiagonalRandomUniformFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - RandomUniformFunc func; - cutlass::FillMode fill_mode; - int alignment; - - // - // Methods - // - - /// Construction of uniform RNG functor. - TensorFillPadDiagonalRandomUniformFunc( - TensorView view_ = TensorView(), - RandomUniformFunc func_ = RandomUniformFunc(), - cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, - int alignment_ = 1 - ): - view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { - - } - - /// Compute random value and update RNG state - void operator()(Coord const &coord) { - // Fill half of matrix based on FillMode - if (Layout::kRank == 2 && - (fill_mode == cutlass::FillMode::kLower) && - (coord[0] >= coord[1]) || - ((coord[1] - coord[0]) >= alignment)) { - view.at(coord) = func(); - } else if (Layout::kRank == 2 && - fill_mode == cutlass::FillMode::kUpper && - (coord[0] <= coord[1]) || - ((coord[0] - coord[1]) >= alignment)) { - view.at(coord) = func(); - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values of a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomUniform( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - double pnan = 0, ///< Percentage of NaN elements. - bool exclude_zero = false) { ///< Exclude zero from tensor init - detail::RandomUniformFunc random_func(seed, max, min, bits, pnan, exclude_zero); - - detail::TensorFillRandomUniformFunc func( - dst, - random_func - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/// Fills a tensor with random values of a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomUniform( - TensorViewPlanarComplex dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - double pnan = 0, ///< Percentage of NaN elements. - bool exclude_zero = false) { ///< Exclude zero from tensor init - - TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero); - TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero); -} - - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomUniform( - TensorView, Layout> dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - detail::RandomUniformFunc> random_func(seed, max, min, bits); - - detail::TensorFillRandomUniformFunc, Layout> func( - dst, - random_func - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillSymmetricRandomUniform( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - - detail::RandomUniformFunc random_func(seed, max, min, bits); - - detail::TensorFillSymmetricRandomUniformFunc func( - dst, - random_func, - fill_mode - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillPadDiagonalRandomUniform( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - int alignment = 1 -) { - - detail::RandomUniformFunc random_func(seed, max, min, bits); - - detail::TensorFillPadDiagonalRandomUniformFunc func( - dst, - random_func, - fill_mode, - alignment - ); - - TensorForEach( - dst.extent(), - func - ); -} -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with a uniform value -template < - typename Element ///< Element type -> -void BlockFill( - Element *ptr, - size_t capacity, - Element val - ) { - for (size_t i = 0; i < capacity; ++i) { - ReferenceFactory::get(ptr, i) = val; - } -} - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element ///< Element type -> -void BlockFillRandomUniform( - Element *ptr, - size_t capacity, - uint64_t seed, ///< seed for RNG - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1, ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - double pnan = 0) { ///< Percentage of NaN elements. - detail::RandomUniformFunc random_func(seed, max, min, bits, pnan); - - for (size_t i = 0; i < capacity; ++i) { - ReferenceFactory::get(ptr, i) = random_func(); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillDiagonalFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - Element diag; - Element other; - - // - // Methods - // - - TensorFillDiagonalFunc( - TensorView const &view_ = TensorView(), - Element diag_ = Element(1), - Element other_ = Element(0) - ): - view(view_), diag(diag_), other(other_) { } - - void operator()(Coord const & coord) const { - bool is_diag = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[i - 1]) { - is_diag = false; - break; - } - } - - view.at(coord) = (is_diag ? diag : other); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor everywhere with a unique value for its diagonal. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillDiagonal( - TensorView dst, ///< destination tensor - Element diag = Element(1), ///< value to write in the diagonal - Element other = Element(0)) { ///< value to write off the diagonal - - detail::TensorFillDiagonalFunc func( - dst, - diag, - other - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillIdentity( - TensorView dst) { ///< destination tensor - - TensorFillDiagonal(dst, Element(1), Element(0)); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorUpdateDiagonal( - TensorView dst, ///< destination tensor - Element val = Element(1)) { - - typename Layout::Index extent = dst.extent().min(); - - for (typename Layout::Index i = 0; i < extent; ++i) { - Coord coord(i); - dst.at(coord) = val; - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorUpdateOffDiagonalFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - Element other; - - // - // Methods - // - - TensorUpdateOffDiagonalFunc( - TensorView const &view_ = TensorView(), - Element other_ = Element(0) - ): - view(view_), other(other_) { } - - void operator()(Coord const & coord) const { - bool is_diag = true; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - if (coord[i] != coord[i - 1]) { - is_diag = false; - break; - } - } - - if (!is_diag) { - view.at(coord) = other; - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorUpdateOffDiagonal( - TensorView dst, ///< destination tensor - Element other = Element(1)) { - - detail::TensorUpdateOffDiagonalFunc func( - dst, - other - ); - - TensorForEach( - dst.extent(), - func - ); -} - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillLinearFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - Array v; - Element s; - - // - // Methods - // - - TensorFillLinearFunc() { } - - /// Constructs functor - TensorFillLinearFunc( - TensorView const &view_, - Array const & v_, - Element s_ = Element(0) - ): - view(view_), v(v_), s(s_) { } - - /// Updates the tensor - void operator()(Coord const & coord) const { - - Element sum(s); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Layout::kRank; ++i) { - sum += Element(coord[i]) * v[i]; - } - - view.at(coord) = sum; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills tensor with a linear combination of its coordinate and another vector -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillLinear( - TensorView dst, ///< destination tensor - Array const & v, - Element s = Element(0)) { - - detail::TensorFillLinearFunc func( - dst, - v, - s - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills tensor with a linear combination of its coordinate and another vector -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillSequential( - TensorView dst, ///< destination tensor - Element s = Element(0)) { - - Array stride; - - stride[0] = Element(1); - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < Layout::kRank; ++i) { - stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); - } - - TensorFillLinear(dst, stride, s); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values from a distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandom( - TensorView view, ///< destination tensor - uint64_t seed, - Distribution dist, - bool exclude_zero = false ///< If true, excludes 0. - /// Note that setting this flag will result in more 1's, - /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. -) { - - using Real = typename RealType::Type; - - if (dist.kind == Distribution::Gaussian) { - TensorFillRandomGaussian( - view, - seed, - dist.gaussian.mean, - dist.gaussian.stddev, - dist.int_scale, - dist.gaussian.pnz, - exclude_zero); - } else if (dist.kind == Distribution::Uniform) { - TensorFillRandomUniform( - view, - seed, - dist.uniform.max, - dist.uniform.min, - dist.int_scale, - dist.uniform.pnan, - exclude_zero); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillSequential( - Element *ptr, - int64_t capacity, - Element v = Element(1), - Element s = Element(0)) { - int i = 0; - - while (i < capacity) { - cutlass::ReferenceFactory::value < - 8)>::get(ptr, i) = s; - - s = Element(s + v); - ++i; - } -} - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillSequentialModN( - Element *ptr, - int64_t capacity, - int64_t mod, - int64_t v = int64_t(1), - int64_t s = int64_t(0)) { - int i = 0; - - while (i < capacity) { - cutlass::ReferenceFactory::value < - 8)>::get(ptr, i) = Element(s); - - s = int64_t(s + v) % mod; - ++i; - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillRandom( - Element *ptr, - size_t capacity, - uint64_t seed, - Distribution dist) { - - if (dist.kind == Distribution::Gaussian) { - BlockFillRandomGaussian( - ptr, - capacity, - seed, - dist.gaussian.mean, - dist.gaussian.stddev, - dist.int_scale, - dist.gaussian.pnz); - } - else if (dist.kind == Distribution::Uniform) { - BlockFillRandomUniform( - ptr, - capacity, - seed, - dist.uniform.max, - dist.uniform.min, - dist.int_scale, - dist.uniform.pnan); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct RandomSparseMetaFunc { - - uint64_t seed; - int range; - int MetaSizeInBits; - - // - // Methods - // - - RandomSparseMetaFunc( - uint64_t seed_ = 0, - int MetaSizeInBits_ = 2 - ): - seed(seed_), MetaSizeInBits(MetaSizeInBits_) { - std::srand((unsigned)seed); - if (MetaSizeInBits_ == 2) { - range = 6; - } - else if (MetaSizeInBits_ == 4) { - range = 2; - } - else { - throw std::invalid_argument("Invalid MetaSizeInBits"); - } - } - - /// Compute random value and update RNG state - Element operator()() const { - Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; - Element TwoToOneMeta[2] = {0x4, 0xe}; - - Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; - - Element result = 0x0; - - for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { - int rnd = std::rand() % range; - Element meta = MetaArray[rnd]; - - result = (Element)(result | ((Element)(meta << (i * 4)))); - } - - return result; - } -}; - -/// Computes a random sparse meta -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -struct TensorFillRandomSparseMetaFunc { - - using TensorView = TensorView; - - // - // Data members - // - - TensorView view; - RandomSparseMetaFunc func; - - // - // Methods - // - - /// Construction of Gaussian RNG functor. - TensorFillRandomSparseMetaFunc( - TensorView view_ = TensorView(), - RandomSparseMetaFunc func_ = RandomSparseMetaFunc() - ): - view(view_), func(func_) { - - } - - /// Compute random value and update RNG state - void operator()(Coord const &coord) const { - - view.at(coord) = func(); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomSparseMeta( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - int MetaSizeInBits) { ///< 2 bit or 4 bit - - detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); - - detail::TensorFillRandomSparseMetaFunc func( - dst, - random_func - ); - - TensorForEach( - dst.extent(), - func - ); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template < - typename Element ///< Element type -> -void BlockFillRandomSparseMeta( - Element *ptr, - size_t capacity, - uint64_t seed, ///< seed for RNG - int MetaSizeInBits) { ///< 2 bit or 4bit - - detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); - - for (size_t i = 0; i < capacity; ++i) { - ptr[i] = random_func(); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a ell block index matrix with random values with a uniform random distribution. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorFillRandomEllIdx( - TensorView dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - int rows, int ell_cols, int cols) { ///< dimension of the matrix - - std::srand((unsigned)seed); - - for (int i = 0; i < rows; ++i) { - int col_idx = std::rand() % cols; - - for (int j = 0; j < ell_cols; ++j) { - dst.at({i, j}) = col_idx; - - if (col_idx != -1) { - if (col_idx == (cols - 1)) { - col_idx = -1; - } else { - col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; - } - } - } - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies a diagonal in from host memory without modifying off-diagonal elements. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorCopyDiagonalIn( - TensorView dst, ///< destination tensor - Element const *ptr) { ///< dense buffer of elements - - typename Layout::Index extent = dst.extent().min(); - - for (typename Layout::Index i = 0; i < extent; ++i) { - Coord coord(i); - dst.at(coord) = ReferenceFactory::get(ptr, i); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Copies the diagonal of a tensor into a dense buffer in host memory. -template < - typename Element, ///< Element type - typename Layout> ///< Layout function -void TensorCopyDiagonalOut( - Element *ptr, ///< dense buffer of elements - TensorView src) { ///< source tensor - - typename Layout::Index extent = src.extent().min(); - - for (typename Layout::Index i = 0; i < extent; ++i) { - Coord coord(i); - ReferenceFactory::get(ptr, i) = src.at(coord); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp deleted file mode 100644 index 1b3df239a1b9d69fc12e7ec4be2de6f87b3a0e3c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp +++ /dev/null @@ -1,432 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Provides several functions for filling tensors with data. -*/ - -#pragma once - -// Standard Library includes -#include -#include -#include - -// Cute includes -#include "cute/tensor.hpp" - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/quaternion.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Uniform and procedural tensor fills -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with a scalar element -template -void TensorFill(Tensor dst, typename Tensor::value_type element) { - - for (int64_t idx = 0; idx < cute::size(dst); ++idx) { - dst(idx) = element; - } -} - -/// Fills a tensor with the contents of its layout -template -void TensorFillSequential(Tensor dst) { - - auto layout = dst.layout(); - - for (int64_t idx = 0; idx < cute::size(dst); ++idx) { - dst(idx) = layout(idx); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Random uniform values -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct RandomUniformFunc { - - using Real = typename RealType::Type; - - uint64_t seed; - double range; - double min; - int int_scale; - - // - // Methods - // - - RandomUniformFunc( - uint64_t seed_ = 0, - double max = 1, - double min_ = 0, - int int_scale_ = -1 - ): - seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { - std::srand((unsigned)seed); - } - - - /// Compute random value and update RNG state - Element operator()() const { - - double rnd = double(std::rand()) / double(RAND_MAX); - - rnd = min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - Element result; - - if (int_scale >= 0) { - rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); - result = static_cast(Real(rnd)); - } - else { - result = static_cast(Real(rnd)); - } - - return result; - } -}; - -/// Partial specialization for initializing a complex value. -template -struct RandomUniformFunc > { - - using Real = typename RealType::Type; - - uint64_t seed; - double range; - double min; - int int_scale; - - // - // Methods - // - - RandomUniformFunc( - uint64_t seed_ = 0, - double max = 1, - double min_ = 0, - int int_scale_ = -1 - ): - seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { - std::srand((unsigned)seed); - } - - - /// Compute random value and update RNG state - complex operator()() const { - - Element reals[2]; - - for (int i = 0; i < 2; ++i) { - double rnd = double(std::rand()) / double(RAND_MAX); - - rnd = min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - - if (int_scale >= 0) { - rnd = double(int(rnd * double(1 << int_scale))); - reals[i] = from_real(Real(rnd / double(1 << int_scale))); - } - else { - reals[i] = from_real(Real(rnd)); - } - } - - return complex(reals[0], reals[1]); - } -}; - -/// Partial specialization for initializing a Quaternion value. -template -struct RandomUniformFunc > { - - using Real = typename RealType::Type; - - uint64_t seed; - double range; - double min; - int int_scale; - - // - // Methods - // - - RandomUniformFunc( - uint64_t seed_ = 0, - double max = 1, - double min_ = 0, - int int_scale_ = -1 - ): - seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { - std::srand((unsigned)seed); - } - - - /// Compute random value and update RNG state - Quaternion operator()() const { - - Element reals[4]; - - for (int i = 0; i < 4; ++i) { - double rnd = double(std::rand()) / double(RAND_MAX); - - rnd = min + range * rnd; - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - - if (int_scale >= 0) { - rnd = double(int(rnd * double(1 << int_scale))); - reals[i] = from_real(Real(rnd / double(1 << int_scale))); - } - else { - reals[i] = from_real(Real(rnd)); - } - } - - return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a uniform random distribution. -template ///< Tensor object -void TensorFillRandomUniform( - Tensor dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - - detail::RandomUniformFunc random_func(seed, max, min, bits); - - for (int64_t idx = 0; idx < cute::size(dst); ++idx) { - dst(idx) = random_func(); - } -} - -/// Fills a block with random values with a uniform random distribution. -template < - typename Element ///< Element type -> -void BlockFillRandomUniform( - Element *ptr, - size_t capacity, - uint64_t seed, ///< seed for RNG - double max = 1, ///< upper bound of distribution - double min = 0, ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - detail::RandomUniformFunc random_func(seed, max, min, bits); - - for (size_t i = 0; i < capacity; ++i) { - ptr[i] = random_func(); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Random Gaussian -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -struct RandomGaussianFunc { - - uint64_t seed; - double mean; - double stddev; - int int_scale; - double pi; - - // - // Methods - // - RandomGaussianFunc( - uint64_t seed_ = 0, - double mean_ = 0, - double stddev_ = 1, - int int_scale_ = -1 - ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { - std::srand((unsigned)seed); - } - - /// Compute random value and update RNG state - Element operator()() const { - - // Box-Muller transform to generate random numbers with Normal distribution - double u1 = double(std::rand()) / double(RAND_MAX); - double u2 = double(std::rand()) / double(RAND_MAX); - - // Compute Gaussian random value - double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); - rnd = mean + stddev * rnd; - - // Scale and convert final result - Element result; - - if (int_scale >= 0) { - rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); - result = static_cast(rnd); - } - else { - result = static_cast(rnd); - } - - return result; - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a tensor with random values with a Gaussian distribution. -template < - typename Tensor -> -void TensorFillRandomGaussian( - Tensor dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); - - for (int64_t idx = 0; idx < cute::size(dst); ++idx) { - dst(idx) = random_func(); - } -} - -/// Fills a block with random values with a Gaussian distribution. -template < - typename Element ///< Element type -> -void BlockFillRandomGaussian( - Element *ptr, ///< destination buffer - size_t capacity, ///< number of elements - uint64_t seed, ///< seed for RNG - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); - - for (size_t i = 0; i < capacity; ++i) { - ptr[i] = random_func(); - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillSequential( - Element *ptr, - int64_t capacity, - Element v = Element(1), - Element s = Element(0)) { - int i = 0; - - while (i < capacity) { - - ptr[i] = Element(s + v); - ++i; - } -} - -/// Fills a block of data with sequential elements -template < - typename Element -> -void BlockFillSequentialModN( - Element *ptr, - int64_t capacity, - int64_t mod, - int64_t v = int64_t(1), - int64_t s = int64_t(0)) { - int i = 0; - - while (i < capacity) { - - ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); - ++i; - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h deleted file mode 100644 index bcb1af995805e3fbcbdbf398ce7191ea2f0dbe8d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h +++ /dev/null @@ -1,134 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include -#include "cutlass/cutlass.h" - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines several helpers -namespace detail { - -/// Helper to perform for-each operation -template -struct TensorForEachHelper { - - /// Index of the active rank - static int const kActiveRank = Rank - RankRemaining - 1; - - /// Constructor for general rank - TensorForEachHelper( - Func &func, - Coord const &extent, - Coord &coord) { - - for (int i = 0; i < extent.at(kActiveRank); ++i) { - coord[kActiveRank] = i; - TensorForEachHelper(func, extent, coord); - } - } -}; - -/// Helper to perform for-each operation -template -struct TensorForEachHelper { - - /// Index of the active rank - static int const kActiveRank = Rank - 1; - - /// Constructor for fastest changing rank - TensorForEachHelper( - Func &func, - Coord const &extent, - Coord &coord) { - - for (int i = 0; i < extent.at(kActiveRank); ++i) { - coord[kActiveRank] = i; - func(coord); - } - } -}; - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Iterates over the index space of a tensor -template < - typename Func, ///< function applied to each point in a tensor's index space - int Rank> ///< rank of index space -void TensorForEach(Coord extent, Func & func) { - Coord coord; - detail::TensorForEachHelper(func, extent, coord); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Iterates over the index space of a tensor and calls a C++ lambda -template < - typename Func, ///< function applied to each point in a tensor's index space - int Rank> ///< rank of index space -void TensorForEachLambda(Coord extent, Func func) { - Coord coord; - detail::TensorForEachHelper(func, extent, coord); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockForEach { - - /// Constructor performs the operation. - BlockForEach( - Element *ptr, - size_t capacity, - typename Func::Params params = typename Func::Params()) { - - Func func(params); - - for (size_t index = 0; index < capacity; ++index) { - ptr[index] = func(); - } - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h deleted file mode 100644 index d44dda1f5472f13b7212f7e2e4020e254ff92f88..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h +++ /dev/null @@ -1,42 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - - -#include "cutlass/cutlass.h" - -// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. - -#include "cutlass/util/reference/host/tensor_reduce.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h deleted file mode 100644 index 887c568059a90f749fc0ac75dd211ce77085a5a9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h +++ /dev/null @@ -1,203 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/util/reference/detail/linear_to_coordinate.h" -#include "cutlass/core_io.h" - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -/// workspace -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorView view, - ComputeType identity, - ReduceOp reduce, - TransformOp transform -) { - - for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) { - typename Layout::TensorCoord coord; - cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); - - if (view.contains(coord)) { - Element x = view.at(coord); - identity = reduce(identity, transform(x)); - } - } - - return identity; -} - -/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -/// workspace -template < - typename Element, - typename Layout, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorView view_A, - TensorView view_B, - ComputeType identity, - ReduceOp reduce, - TransformOp transform) { - - if (view_A.extent() != view_B.extent()) { - throw std::runtime_error("Tensor extents must match."); - } - - for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) { - - typename Layout::TensorCoord coord; - cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); - - if (view_A.contains(coord)) { - Element a = view_A.at(coord); - Element b = view_B.at(coord); - identity = reduce(identity, transform(a, b)); - } - } - - return identity; -} - -/// Helper to compute the sum of the elements of a tensor -template < - typename Element, - typename Layout, - typename ComputeType = Element -> -ComputeType TensorSum( - TensorView view, - ComputeType identity = ComputeType() -) { - - plus reduce; - NumericConverter transform; - - return TensorTransformReduce( - view, identity, reduce, transform); -} - -/// Helper to compute the sum of the squares of the elements of a tensor -template < - typename Element, - typename Layout, - typename ComputeType = Element -> -ComputeType TensorSumSq( - TensorView view, - ComputeType identity = ComputeType() -) { - - plus reduce; - magnitude_squared transform; - - return TensorTransformReduce( - view, identity, reduce, transform); -} - -/// Helper to compute the norm of the elements of a tensor. -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorNorm( - TensorView view, - ComputeType identity = ComputeType() -) { - - return std::sqrt(TensorSumSq(view, identity)); -} - -/// Helper to compute the sum of the squares of the differences of two tensors -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorSumSqDiff( - TensorView view_A, - TensorView view_B, - ComputeType identity = ComputeType() -) { - - plus reduce; - magnitude_squared_difference transform; - - return TensorTransformReduce( - view_A, view_B, identity, reduce, transform); -} - - -/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -template < - typename Element, - typename Layout, - typename ComputeType = double -> -ComputeType TensorNormDiff( - TensorView view_A, - TensorView view_B, - ComputeType identity = ComputeType() -) { - - return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp deleted file mode 100644 index ea711466df86703aae1702605a928754c9f4e944..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp +++ /dev/null @@ -1,203 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Provides several functions for filling tensors with data. -*/ - -#pragma once - -// Standard Library includes -#include -#include -#include - -// Cute includes -#include "cute/tensor.hpp" - -// Cutlass includes -#include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/quaternion.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace reference { -namespace host { - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Tensor reductions -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -/// workspace -template < - typename Tensor, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - Tensor view, - ComputeType identity, - ReduceOp reduce, - TransformOp transform -) { - - for (int64_t idx = 0; idx < cute::size(view); ++idx) { - identity = reduce(identity, transform(view(idx))); - } - - return identity; -} - -/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -/// workspace -template < - typename TensorA, - typename TensorB, - typename ComputeType, - typename ReduceOp, - typename TransformOp -> -ComputeType TensorTransformReduce( - TensorA view_A, - TensorB view_B, - ComputeType identity, - ReduceOp reduce, - TransformOp transform) { - - if (cute::size(view_A) != cute::size(view_B)) { - throw std::runtime_error("Tensor sizes must match."); - } - - for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { - identity = reduce(identity, transform(view_A(idx), view_B(idx))); - } - - return identity; -} - -/// Helper to compute the sum of the elements of a tensor -template < - typename Tensor, - typename ComputeType = typename Tensor::value_type -> -ComputeType TensorSum( - Tensor view, - ComputeType identity = ComputeType() -) { - - plus reduce; - NumericConverter transform; - - return TensorTransformReduce( - view, identity, reduce, transform); -} - -/// Helper to compute the sum of the squares of the elements of a tensor -template < - typename Tensor, - typename ComputeType = typename Tensor::value_type -> -ComputeType TensorSumSq( - Tensor view, - ComputeType identity = ComputeType() -) { - - plus reduce; - magnitude_squared transform; - - return TensorTransformReduce( - view, identity, reduce, transform); -} - -/// Helper to compute the norm of the elements of a tensor. -template < - typename Tensor, - typename ComputeType = double -> -ComputeType TensorNorm( - Tensor view, - ComputeType identity = ComputeType() -) { - - return std::sqrt(TensorSumSq(view, identity)); -} - -/// Helper to compute the sum of the squares of the differences of two tensors -template < - typename TensorA, - typename TensorB, - typename ComputeType = double -> -ComputeType TensorSumSqDiff( - TensorA view_A, - TensorB view_B, - ComputeType identity = ComputeType() -) { - - plus reduce; - magnitude_squared_difference transform; - - return TensorTransformReduce( - view_A, view_B, identity, reduce, transform); -} - - -/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -template < - typename TensorA, - typename TensorB, - typename ComputeType = double -> -ComputeType TensorNormDiff( - TensorA view_A, - TensorB view_B, - ComputeType identity = ComputeType() -) { - - return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h deleted file mode 100644 index 09b1aff9c0ea9922af46c928a3dd61595be2e4cd..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h +++ /dev/null @@ -1,215 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for TRMM in host-side code. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/mma.h" -#include "cutlass/util/host_tensor.h" - -#include "cutlass/util/reference/host/gemm.h" - -namespace cutlass { -namespace reference { -namespace host { - -/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - DiagType DiagTypeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_trmm( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - TensorRef tensor_d, - ComputeType initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - static_assert(SideModeA != SideMode::kInvalid - , "Side Mode can either be Left or Right."); - - static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper - , "Fill Mode can either be Lower or Upper."); - - using CompareOp = typename TrMatrixCompareOp::Type; - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - // Assuming correct k-dimension value is passed - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - CompareOp compare_op; - - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - ElementA a = ElementA(); - ElementB b = ElementB(); - - if (SideModeA == SideMode::kLeft) { - a = (compare_op(row, k_block)) ? - (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); - if (row == k_block && DiagTypeA == DiagType::kUnit) { - a = ElementA(1); - } - b = tensor_b.at(MatrixCoord(k_block, col)); - } else if (SideModeA == SideMode::kRight) { - a = tensor_b.at(MatrixCoord(row, k_block)); - b = (compare_op(k_block, col)) ? - tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); - if (k_block == col && DiagTypeA == DiagType::kUnit) { - b = ElementA(1); - } - } - - ComputeType compute_a(cast_if_scalar(a)); - ComputeType compute_b(cast_if_scalar(b)); - - accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j])); - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - SideMode SideModeA, - FillMode FillModeA, - DiagType DiagTypeA, - typename ElementB, - typename LayoutB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = cutlass::arch::OpMultiplyAdd -> -struct Trmm; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct Trmm { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_trmm>( - problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h deleted file mode 100644 index e8db2a4deaf8608882595d68e611f8ae79e134e8..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h +++ /dev/null @@ -1,262 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for complex-valued TRMM in host-side code. - - -*/ - -#pragma once - -#include "cutlass/blas3.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/util/reference/host/gemm.h" - -namespace cutlass { -namespace reference { -namespace host { - -/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef -/// objects. -template < - typename ElementA, - typename LayoutA, - ComplexTransform TransformA, - SideMode SideModeA, - FillMode FillModeA, - DiagType DiagTypeA, - typename ElementB, - typename LayoutB, - ComplexTransform TransformB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = multiply_add, - typename ConvertOp = NumericConverter -> -void compute_trmm_complex( - gemm::GemmCoord problem_size, - ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - TensorRef tensor_d, - ComputeType initial_accum) { - - static_assert( - LayoutA::kRank == 2 && - LayoutC::kRank == 2, "Tensors must be of rank 2"); - - static_assert(SideModeA != SideMode::kInvalid - , "Side Mode can either be Left or Right."); - - static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper - , "Fill Mode can either be Lower or Upper."); - - using CompareOp = typename TrMatrixCompareOp::Type; - - // Note: batch is ignored. - int const M = problem_size.m(); - int const N = problem_size.n(); - // Assuming correct k-dimension value is passed - int const K = problem_size.k(); - - // Blocking necessary to speedup reference implementation - int const Mblock = 16; - int const Nblock = 16; - - ConvertOp convert_op; - InnerProductOp inner_product_op; - CompareOp compare_op; - - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { - - ComputeType accum[Mblock][Nblock]; - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } - - for (int k_block = 0; k_block < K; ++k_block) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - if (row < M && col < N) { - ElementA a = ElementA(); - ElementB b = ElementB(); - - if (SideModeA == SideMode::kLeft) { - a = (compare_op(row, k_block)) ? - (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); - if (row == k_block && DiagTypeA == DiagType::kUnit) { - a = ElementA(1); - } - b = tensor_b.at(MatrixCoord(k_block, col)); - } else if (SideModeA == SideMode::kRight) { - a = tensor_b.at(MatrixCoord(row, k_block)); - b = (compare_op(k_block, col)) ? - tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); - if (k_block == col && DiagTypeA == DiagType::kUnit) { - b = ElementA(1); - } - } - - ComputeType a_ik = ComputeType(a); - ComputeType b_kj = ComputeType(b); - - // Conjugate, and hence hermitian, is only allowed for the triangular matrix - if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { - b_kj = conj(b_kj); - } - - accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); - } - } - } - } - - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; - - MatrixCoord coord = MatrixCoord(row, col); - - if (row < M && col < N) { - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j])); - } - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ElementA, - typename LayoutA, - ComplexTransform TransformA, - SideMode SideModeA, - FillMode FillModeA, - DiagType DiagTypeA, - typename ElementB, - typename LayoutB, - ComplexTransform TransformB, - typename ElementC, - typename LayoutC, - typename ScalarType, - typename ComputeType, - typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex -> -struct TrmmComplex; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for multiply-add -template -struct TrmmComplex { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_trmm_complex>( - problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for gaussian multiply-add -template -struct TrmmComplex { - - void operator()(gemm::GemmCoord problem_size, ScalarType alpha, - TensorRef tensor_a, - TensorRef tensor_b, - TensorRef tensor_d, - ComputeType initial_accum = ComputeType(0)) { - static_assert( - LayoutA::kRank == 2 && LayoutC::kRank == 2, - "Tensors must be of rank 2"); - - compute_trmm_complex>( - problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h deleted file mode 100644 index 0ce1d8a65fdd66ace69f91525b678dd6ad132d24..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h +++ /dev/null @@ -1,270 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -* -**************************************************************************************************/ -#pragma once - -#include "cutlass/core_io.h" -#include "cutlass/tensor_view.h" -#include "cutlass/tensor_view_planar_complex.h" -#include "cutlass/complex.h" - -namespace cutlass { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Helper to write the least significant rank of a TensorView -template < - typename Element, - typename Layout -> -inline std::ostream & TensorView_WriteLeastSignificantRank( - std::ostream& out, - TensorView const& view, - Coord const &start_coord, - int rank, - std::streamsize width) { - - for (int idx = 0; idx < view.extent(rank); ++idx) { - - Coord coord(start_coord); - coord[rank] = idx; - - if (idx) { - out.width(0); - out << ", "; - } - if (idx || coord) { - out.width(width); - } - out << ScalarIO(view.at(coord)); - } - - return out; -} - -/// Helper to write a rank of a TensorView -template < - typename Element, - typename Layout -> -inline std::ostream & TensorView_WriteRank( - std::ostream& out, - TensorView const& view, - Coord const &start_coord, - int rank, - std::streamsize width) { - - // If called on the least significant rank, write the result as a row - if (rank + 1 == Layout::kRank) { - return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); - } - - // Otherwise, write a sequence of rows and newlines - for (int idx = 0; idx < view.extent(rank); ++idx) { - - Coord coord(start_coord); - coord[rank] = idx; - - if (rank + 2 == Layout::kRank) { - // Write least significant ranks asa matrix with rows delimited by "\n" - if (idx) { - out << ",\n"; - } - TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); - } - else { - // Higher ranks are separated by newlines - if (idx) { - out << ",\n\n"; - } - TensorView_WriteRank(out, view, coord, rank + 1, width); - } - } - - return out; -} - -/// Helper to write the least significant rank of a TensorView -template < - typename Element, - typename Layout -> -inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( - std::ostream& out, - TensorViewPlanarComplex const& view, - Coord const &start_coord, - int rank, - std::streamsize width) { - - for (int idx = 0; idx < view.extent(rank); ++idx) { - - Coord coord(start_coord); - coord[rank] = idx; - - if (idx) { - out.width(0); - out << ", "; - } - if (idx || coord) { - out.width(width); - } - - complex x = view.at(coord); - out << x; - } - - return out; -} - -/// Helper to write a rank of a TensorView -template < - typename Element, - typename Layout -> -inline std::ostream & TensorViewPlanarComplex_WriteRank( - std::ostream& out, - TensorViewPlanarComplex const& view, - Coord const &start_coord, - int rank, - std::streamsize width) { - - // If called on the least significant rank, write the result as a row - if (rank + 1 == Layout::kRank) { - return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); - } - - // Otherwise, write a sequence of rows and newlines - for (int idx = 0; idx < view.extent(rank); ++idx) { - - Coord coord(start_coord); - coord[rank] = idx; - - if (rank + 2 == Layout::kRank) { - // Write least significant ranks asa matrix with rows delimited by ";\n" - if (idx) { - out << ";\n"; - } - TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); - } - else { - // Higher ranks are separated by newlines - if (idx) { - out << "\n"; - } - TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); - } - } - - return out; -} - -} // namespace detail - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Prints human-readable representation of a TensorView to an ostream -template < - typename Element, - typename Layout -> -inline std::ostream& TensorViewWrite( - std::ostream& out, - TensorView const& view) { - - // Prints a TensorView according to the following conventions: - // - least significant rank is printed as rows separated by ";\n" - // - all greater ranks are delimited with newlines - // - // The result is effectively a whitespace-delimited series of 2D matrices. - - return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); -} - -/// Prints human-readable representation of a TensorView to an ostream -template < - typename Element, - typename Layout -> -inline std::ostream& operator<<( - std::ostream& out, - TensorView const& view) { - - // Prints a TensorView according to the following conventions: - // - least significant rank is printed as rows separated by ";\n" - // - all greater ranks are delimited with newlines - // - // The result is effectively a whitespace-delimited series of 2D matrices. - - return TensorViewWrite(out, view); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Prints human-readable representation of a TensorView to an ostream -template < - typename Element, - typename Layout -> -inline std::ostream& TensorViewWrite( - std::ostream& out, - TensorViewPlanarComplex const& view) { - - // Prints a TensorView according to the following conventions: - // - least significant rank is printed as rows separated by ";\n" - // - all greater ranks are delimited with newlines - // - // The result is effectively a whitespace-delimited series of 2D matrices. - - return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); -} - -/// Prints human-readable representation of a TensorView to an ostream -template < - typename Element, - typename Layout -> -inline std::ostream& operator<<( - std::ostream& out, - TensorViewPlanarComplex const& view) { - - // Prints a TensorView according to the following conventions: - // - least significant rank is printed as rows separated by ";\n" - // - all greater ranks are delimited with newlines - // - // The result is effectively a whitespace-delimited series of 2D matrices. - - return TensorViewWrite(out, view); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h deleted file mode 100644 index 5dfbfe274dec368cfac291a1c78ece6ffb203c72..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h +++ /dev/null @@ -1,238 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Type traits for common CUDA types -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/numeric_types.h" -#include "cutlass/complex.h" - -namespace cutlass { -struct half_t; - -template -struct TypeTraits { - typedef T host_type; - typedef T device_type; - static inline T remove_negative_zero(T x) { return x; } - static inline T to_print(T x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_8I; - typedef int8_t host_type; - typedef int8_t device_type; - typedef int8_t integer_type; - typedef uint8_t unsigned_type; - static inline int8_t remove_negative_zero(int8_t x) { return x; } - static inline int to_print(int8_t x) { return (int)x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_8I; - typedef uint8_t host_type; - typedef uint8_t device_type; - typedef uint8_t integer_type; - typedef uint8_t unsigned_type; - static inline uint8_t remove_negative_zero(uint8_t x) { return x; } - static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_32I; - typedef int host_type; - typedef int device_type; - typedef int32_t integer_type; - typedef uint32_t unsigned_type; - static inline int32_t remove_negative_zero(int32_t x) { return x; } - static inline int to_print(int x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_32I; - typedef unsigned host_type; - typedef unsigned device_type; - typedef uint32_t integer_type; - typedef uint32_t unsigned_type; - static inline uint32_t remove_negative_zero(uint32_t x) { return x; } - static inline uint32_t to_print(uint32_t x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_8I; - typedef int64_t host_type; - typedef int64_t device_type; - typedef int64_t integer_type; - typedef uint64_t unsigned_type; - static inline int64_t remove_negative_zero(int64_t x) { return x; } - static inline int64_t to_print(int64_t x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_8I; - typedef uint64_t host_type; - typedef uint64_t device_type; - typedef uint64_t integer_type; - typedef uint64_t unsigned_type; - static inline uint64_t remove_negative_zero(uint64_t x) { return x; } - static inline uint64_t to_print(uint64_t x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_16F; - typedef half_t host_type; - typedef half_t device_type; - typedef int16_t integer_type; - typedef uint16_t unsigned_type; - static inline half_t remove_negative_zero(half_t x) { - return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); - } - static inline half_t to_print(half_t x) { return x; } - static inline device_type to_device(half_t x) { return reinterpret_cast(x); } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_32F; - typedef float host_type; - typedef float device_type; - typedef int32_t integer_type; - typedef uint32_t unsigned_type; - static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; } - static inline float to_print(float x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -template <> -struct TypeTraits { - static cudaDataType_t const cublas_type = CUDA_R_64F; - typedef double host_type; - typedef double device_type; - typedef int64_t integer_type; - typedef uint64_t unsigned_type; - static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } - static inline double to_print(double x) { return x; } - static inline device_type to_device(host_type x) { return x; } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Complex types -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct TypeTraits > { - static cudaDataType_t const cublas_type = CUDA_C_16F; - typedef complex host_type; - typedef complex device_type; - typedef int16_t integer_type; - typedef uint16_t unsigned_type; - static inline device_type to_device(complex x) { return reinterpret_cast(x); } -}; - -template <> -struct TypeTraits > { - static cudaDataType_t const cublas_type = CUDA_C_16F; - typedef complex host_type; - typedef complex device_type; - typedef int16_t integer_type; - typedef uint16_t unsigned_type; - static inline complex remove_negative_zero(complex x) { - return complex( - real(x) == -0_hf ? 0_hf : real(x), - imag(x) == -0_hf ? 0_hf : imag(x) - ); - } - static inline complex to_print(complex x) { return x; } - static inline device_type to_device(complex x) { return reinterpret_cast(x); } -}; - -template <> -struct TypeTraits > { - - static cudaDataType_t const cublas_type = CUDA_C_32F; - typedef complex host_type; - typedef complex device_type; - typedef int64_t integer_type; - typedef uint64_t unsigned_type; - - static inline complex remove_negative_zero(complex x) { - return complex( - real(x) == -0.f ? 0.f : real(x), - imag(x) == -0.f ? 0.f : imag(x) - ); - } - - static inline complex to_print(complex x) { return x; } - static inline device_type to_device(complex x) { return reinterpret_cast(x); } -}; - -template <> -struct TypeTraits > { - static cudaDataType_t const cublas_type = CUDA_C_64F; - typedef complex host_type; - typedef complex device_type; - struct integer_type { int64_t real, imag; }; - struct unsigned_type { uint64_t real, imag; }; - static inline complex remove_negative_zero(complex x) { - return complex( - real(x) == -0.0 ? 0.0 : real(x), - imag(x) == -0.0 ? 0.0 : imag(x) - ); - } - static inline complex to_print(complex x) { return x; } - static inline device_type to_device(complex x) { return reinterpret_cast(x); } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py b/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py deleted file mode 100644 index 6541ce1b26722ff1f0dba0b4c034067a62f9b96d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py +++ /dev/null @@ -1,356 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - - -""" -Given a set of test files to be included in a CMake target, this script extracts -the TEST definitions from each file, writes them into new files, and prints the names -of the new files so that they can be processed as part of a new CMake target. - -For example, given a set of --src_files test_a.cu test_b.cu containing 3 and 2 TEST -definitions, respectively, this script would produce: - test_a_000.cu - test_a_001.cu - test_a_002.cu - test_b_000.cu - test_b_001.cu - -The splitting follows a fairly rudimentary algorithm that does not support all valid C++ programs. -We walk through a given input test file line by line. Any lines that are not within a TEST definition is added to a running -"filler" text. When a TEST definition is encountered, the current filler text becomes the prefix -for that test. All subsequent lines are considered to be part of the TEST definition until the -number of starting function braces ('{') match the number of closing function braces ('}'). When -these counts are equal, the TEST definition is considered to be completed. At this point, we return -to adding lines to the "filler" text until a new TEST definition is encountered. Any "filler" text -following a TEST definition is added to the suffix of that TEST definition (this is useful for finishing -off #if statements, as is common in unit tests.). - -A state machine illustrating this algorithm at a high level is provided in the source below. - -Example: Suppose an input test `test.cu` has the following source: - // COPYRIGHT - #include - - #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - - // Test #1 - TEST(SM90_a, 256x128x64_2x2x1) { - std::cout << "Test #1" << std::endl; - } - - // Test #2 - TEST(SM90_b, 256x128x64_1x1x1) { - std::cout << "Test #2" << std::endl; - } - - #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - -The contents of the two resulting test files will be: - $ cat test_000.cu - // COPYRIGHT - #include - - #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - - // Test #1 - TEST(SM90_a, 256x128x64_2x2x1) { - std::cout << "Test #1" << std::endl; - } - - // Test #2 - - #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - $ cat test_001.cu - // COPYRIGHT - #include - - #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - - // Test #1 - - // Test #2 - TEST(SM90_b, 256x128x64_1x1x1) { - std::cout << "Test #2" << std::endl; - } - - #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - -Notice that each of test_000.cu and test_001.cu contain comments that appear outside -the TEST definitions not included in each file. This is by design, as these -would be considered "filler" text. - -As expected, some cases can't be handled. Below is a non-exhaustive list: - 1. New TEST following the closing '}' of a TEST case on the same line: - TEST(x, y) { - // Do stuff - } TEST(a, b) { - - In this case, "TEST(a, b) {" will be ignored - - 2. Preprocessor macros that occur midway through a test case and extend - beyond the conclusion of a testcase - - Example: - TEST(a, b) { - // Do stuff - #if X - // Do more stuff - } - #else - // Do other stuff - } - #endif -""" - - -import argparse -import enum -import os - - -parser = argparse.ArgumentParser() -parser.add_argument("cmake_target", type=str, - help="Name of the CMake target being generated.") -parser.add_argument("src_dir", type=str, - help="Path to the directory containing test files.") -parser.add_argument("--src_files", nargs='+', - help="Files containing TEST instances to split.") -parser.add_argument("--max_tests_per_file", type=int, default=1, - help="Maximum number of TEST instances per file.") -parser.add_argument("--dst_dir", type=str, - help="Path to the directory to which to write new test files. If not set, uses src_dir.") -args = parser.parse_args() - - -if args.dst_dir == None: - args.dst_dir = args.src_dir - - -class Testcase: - """ - Lightweight tracker of test-case processing status - """ - def __init__(self, prefix_text): - # Any text that preceded the TEST definition that was - # not part of another TEST definition - self.prefix = prefix_text - - # Any text within the TEST definition - self.test = "" - - # Any text that follows the completion of the TEST definition - # and is not included in other TEST definitions - self.suffix = "" - - # Whether the test's definition has concluded - self.completed = False - - # Current balance of opening and closing curly brackets in - # the TEST definition. '{' increments the count and '}' decrements it. - # A value of 0 (when self.completed == False) indicates that the test - # has completed. - self.curly_bracket_balance = 0 - - -class ParseState(enum.Enum): - """ - State machine for processing. - Transitions occur on each line encountered in the soruce file - - - Line does not contain 'TEST(' - +----+ - | | - | v 'TEST(' - +--------+ encountered +--------------------------+ - ------>| Filler | -----------------------> | TestDeclaredWaitingStart | - +--------+ +--------------------------+ - ^ | - Number of '{' | | First '{' encountered - equals number of | +--------+ | - '}' encountered +-----------| InTest | <------------------+ - +--------+ - | ^ - | | - +----+ - Number of '{' encountered - exceeds number of '}' encountered - """ - - - # Any text that is not part of a TEST case - Filler = 0 - - # Processing text within the first { of the TEST case - # and before the en of the final } of the TEST case - InTest = 1 - - # Processing text from the start of the TEST definition - # but before the first {. This could occur if the opening { - # occurs on a separate line than the TEST definition. - TestDeclaredWaitingStart = 2 - - -cmake_src_list = [] -for filename in args.src_files: - if '.' not in filename: - # Add any non-filename arguments to the command list by default - cmake_src_list.append(filename) - continue - - if '/' in filename: - raise Exception( - f"Source files passed to {__file__} must be within the same directory " - "as the CMakeLists defining the target using the files. " - f"Provided path {filename} is in a different directory.") - - full_filename = os.path.join(args.src_dir, filename) - with open(full_filename, 'r') as infile: - lines = infile.readlines() - - # Find the number of instances of "TEST(" - ntest = sum([1 for line in lines if "TEST(" in line]) - - if ntest <= args.max_tests_per_file: - # File contains fewer than max_tests_per_file TEST instances. It does - # not need to be split - cmake_src_list.append(filename) - continue - - # Current state of the parsing state machine. We start with filler text - state = ParseState.Filler - - # List of individual TESTs found - tests = [] - - # Ongoing text that is not included in a TEST definition. This will serve - # as the prefix for any yet-to-be encountered TEST definitions. - filler_text = "" - - def add_filler_text(text): - global filler_text - # Add new text to the ongoing filler text and to the suffixes of - # any completed tests - filler_text += text - for i in range(len(tests)): - if tests[i].completed: - tests[i].suffix += text - - for line in lines: - if state == ParseState.Filler: - # We are not currently within a TEST definition. - - if 'TEST(' in line: - # We have encountered a new TEST( case. Any text preceding this - # must be added to the filler text (e.g., if we have a line of the form: - # "static constexpr int Val = 4; TEST(blah) {" - # then "static constexpr int Val = 4;" needs to be included in filler - # text, as it could be used by subsequent tests.) - splits = line.split('TEST') - - # There should not be more than one TEST definition on a given line - assert len(splits) <= 2 - - if len(splits) > 1: - if not splits[0].isspace(): - # Only add text to filler if there are non-whitespace charcters - # preceding the TEST definition in the line - filler_text += splits[0] - - # The new line is just the TEST-related line - line = 'TEST' + splits[-1] - - # Add tests and transtion to TestDeclaredWaitingStart state. - # Do not add the line to the test text of the new test case; this - # will be done in either the TestDeclaredWaitingStart state processing - # below or in the InTest state processing below. - tests.append(Testcase(filler_text)) - state = ParseState.TestDeclaredWaitingStart - else: - # Any remaining filler text is added to the running filler_text - # which will be used as the prefix for any new tests, and to the - # suffix of any completed tests - add_filler_text(line) - - if state == ParseState.TestDeclaredWaitingStart: - # We have seen a TEST definition but have not yet seen its opening {. - - if '{' in line: - # The first curly bracket for the TEST definition has been found. - # Advance to state InTests. Do not add the line to the test's text - # or change the curly-brace balance of the test; these will be done - # when processing the state == ParseState.InTest condition below. - state = ParseState.InTest - else: - tests[-1].test += line - - if state == ParseState.InTest: - # We are currently within a TEST definition. - # Process lines character-by-character looking for opening and closing - # braces. If we reach parity between opening and closing braces, the - # test is considered done. - filler_text_to_add = "" - for char in line: - if not tests[-1].completed: - tests[-1].test += char - if char == '{': - tests[-1].curly_bracket_balance += 1 - elif char == '}': - tests[-1].curly_bracket_balance -= 1 - if tests[-1].curly_bracket_balance == 0: - tests[-1].completed = True - else: - filler_text_to_add += char - - if filler_text_to_add != "" and (not filler_text_to_add.isspace() or '\n' in filler_text_to_add): - add_filler_text('\n' + filler_text_to_add) - - if tests[-1].completed: - state = ParseState.Filler - - # Write out the new files for tests - filename_prefix, filename_suffix = filename.split('.') - for i, test in enumerate(tests): - assert test.completed - new_filename = filename_prefix + '_' + str(i).zfill(3) + '.' + filename_suffix - full_new_filename = os.path.join(args.dst_dir, new_filename) - - # Replace any '\' with '/'. CMake doesn't like '\'. - full_new_filename = full_new_filename.replace('\\', '/') - - with open(full_new_filename, 'w') as outfile: - outfile.write(test.prefix + test.test + test.suffix) - cmake_src_list.append(full_new_filename) - - -for cmake_file in cmake_src_list: - print(cmake_file) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json deleted file mode 100644 index 4899badb63d45293425e2164944268b6058af95d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "version": 1, - "license": "MIT", - "python-depends": [], - "backend": { - "type": "cuda", - "archs": [ - "9.0a" - ] - } -} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/testing/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/testing/__init__.py deleted file mode 100644 index 13a9d78dea58a6492183f9ddc50f1510a679cbe6..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/testing/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . import bench, numeric, utils -from .bench import * -from .numeric import * -from .utils import * diff --git a/build/torch29-cxx11-cu130-x86_64-linux/testing/bench.py b/build/torch29-cxx11-cu130-x86_64-linux/testing/bench.py deleted file mode 100644 index 2c752da2d3bb0aba7e03ef1921428432b396917a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/testing/bench.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import sys -import torch - - -def bench(fn, num_warmups: int = 5, num_tests: int = 10, - high_precision: bool = False): - # Flush L2 cache with 256 MB data - torch.cuda.synchronize() - cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') - cache.zero_() - - # Warmup - for _ in range(num_warmups): - fn() - - # Add a large kernel to eliminate the CPU launch overhead - if high_precision: - x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') - y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') - x @ y - - # Testing - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for i in range(num_tests): - fn() - end_event.record() - torch.cuda.synchronize() - - return start_event.elapsed_time(end_event) / num_tests / 1e3 - - -class empty_suppress: - def __enter__(self): - return self - - def __exit__(self, *_): - pass - - -class suppress_stdout_stderr: - def __enter__(self): - self.outnull_file = open(os.devnull, 'w') - self.errnull_file = open(os.devnull, 'w') - - self.old_stdout_fileno_undup = sys.stdout.fileno() - self.old_stderr_fileno_undup = sys.stderr.fileno() - - self.old_stdout_fileno = os.dup(sys.stdout.fileno()) - self.old_stderr_fileno = os.dup(sys.stderr.fileno()) - - self.old_stdout = sys.stdout - self.old_stderr = sys.stderr - - os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) - os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) - - sys.stdout = self.outnull_file - sys.stderr = self.errnull_file - return self - - def __exit__(self, *_): - sys.stdout = self.old_stdout - sys.stderr = self.old_stderr - - os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) - os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) - - os.close(self.old_stdout_fileno) - os.close(self.old_stderr_fileno) - - self.outnull_file.close() - self.errnull_file.close() - - -def bench_kineto(fn, kernel_names, num_tests: int = 30, - suppress_kineto_output: bool = False, - trace_path: str = None, flush_l2: bool = True, - with_multiple_kernels: bool = False): - assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tuple = isinstance(kernel_names, tuple) - - # Skip profiling - # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer - if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): - return (1, ) * len(kernel_names) if is_tuple else 1 - - # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle - flush_l2_size = int(8e9 // 4) - - # For some auto-tuning kernels with prints - fn() - - # Profile - suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress - with suppress(): - schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) - profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) - with profiler: - for i in range(2): - for _ in range(num_tests): - if flush_l2: - torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() - fn() - profiler.step() - - # Parse the profiling table - prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') - kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names - if not with_multiple_kernels: - for name in kernel_names: - assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' - - # Save chrome traces - if trace_path is not None: - profiler.export_chrome_trace(trace_path) - - # Return average kernel times - units = {'ms': 1e3, 'us': 1e6} - kernel_times = [] - for name in kernel_names: - total_time = 0 - total_num = 0 - for line in prof_lines: - if name in line: - time_str = line.split()[-2] - num_str = line.split()[-1] - for unit, scale in units.items(): - if unit in time_str: - total_time += float(time_str.replace(unit, '')) / scale * int(num_str) - total_num += int(num_str) - break - kernel_times.append(total_time / total_num if total_num > 0 else 0) - - return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/testing/numeric.py b/build/torch29-cxx11-cu130-x86_64-linux/testing/numeric.py deleted file mode 100644 index a42c4318db47593c47a4ea89fbdbcb1ffb5cd30e..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/testing/numeric.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from typing import Iterable - - -def calc_diff(x: torch.Tensor, y: torch.Tensor): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - if denominator == 0: # Which means that all elements in x and y are 0 - return 0.0 - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def count_bytes(*tensors): - total = 0 - for t in tensors: - if isinstance(t, (tuple, list)): - total += count_bytes(*t) - elif t is not None: - total += t.numel() * t.element_size() - return total diff --git a/build/torch29-cxx11-cu130-x86_64-linux/testing/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/testing/utils.py deleted file mode 100644 index 2d202d4192ed385f986ac5cc216acc69378d8ea9..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/testing/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import functools -import os -import torch -from typing import Callable - -def get_arch_major() -> int: - major, minor = torch.cuda.get_device_capability() - return major - - -def test_filter(condition: Callable): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if condition(): - func(*args, **kwargs) - else: - print(f'{func.__name__}:') - print(f' > Filtered by {condition}') - print() - return wrapper - return decorator - - -def ignore_env(name: str, condition: Callable): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if condition(): - saved = os.environ.pop(name, None) - func(*args, **kwargs) - if saved is not None: - os.environ[name] = saved - else: - func(*args, **kwargs) - - return wrapper - return decorator diff --git a/build/torch29-cxx11-cu130-x86_64-linux/utils/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/utils/__init__.py deleted file mode 100644 index e8f859a20726fcc0ea32c54ed8df37b19b3960a4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import math, layout -from .layout import * -from .math import * diff --git a/build/torch29-cxx11-cu130-x86_64-linux/utils/layout.py b/build/torch29-cxx11-cu130-x86_64-linux/utils/layout.py deleted file mode 100644 index a6bc29d9aaae296a83b8c3546b832a083ade6b28..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/utils/layout.py +++ /dev/null @@ -1,25 +0,0 @@ -from .._ops import ops - - -def get_mk_alignment_for_contiguous_layout(): - return ops.get_mk_alignment_for_contiguous_layout() - - -def get_tma_aligned_size(mn: int, element_size: int): - return ops.get_tma_aligned_size(mn, element_size).item() - - -def get_mn_major_tma_aligned_tensor(sf): - return ops.get_mn_major_tma_aligned_tensor(sf) - - -def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): - return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) - - -def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): - return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) - - -get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout -get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/build/torch29-cxx11-cu130-x86_64-linux/utils/math.py b/build/torch29-cxx11-cu130-x86_64-linux/utils/math.py deleted file mode 100644 index c65026e54b87faf34b498d14d3f81a94759615f4..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/utils/math.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch -from typing import Tuple - - -def ceil_div(x: int, y: int) -> int: - return (x + y - 1) // y - - -def align(x: int, y: int) -> int: - return ceil_div(x, y) * y - - -def ceil_to_ue8m0(x: torch.Tensor): - assert x.view(-1).amax().item() > 0 - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) - - -def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - padded_n = align(n, gran_k) - x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) - x_padded[:, :n] = x - x_view = x_padded.view(m, -1, gran_k) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - sf = x_amax / 448.0 - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf - - -def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(0) % gran_k == 0 - m, n = x.shape - x_view = x.view(-1, gran_k, n) - x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) - sf = x_amax / 448.0 - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf - - -def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) - - -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: - excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) - x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - return x_scaled, sf.squeeze() - - -def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: - ax = x.abs().clamp_max(6.0) - # {0, 0.5, 1, 1.5, 2, 3, 4, 6} - # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 - boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], - device=x.device, dtype=ax.dtype) - idx = torch.bucketize(ax, boundaries) - code = idx.to(torch.uint8) - sign = (x < 0) & (idx != 0) - code = code | (sign.to(torch.uint8) << 3) - return code # uint8, 0..15 - - -def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - assert n % 2 == 0 - padded_n = align(n, gran_k) - x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) - x_padded[:, :n] = x - x_view = x_padded.view(m, -1, gran_k) - x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) - sf = x_amax / 6.0 - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - x_scaled = x_view * (1.0 / sf.unsqueeze(2)) - codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) - codes2 = codes.view(m, padded_n // 2, 2) - packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 - return packed[:, :n // 2].contiguous(), sf - - -def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: - assert a.dtype == torch.uint8 - assert a.dim() == 2 - m, n2 = a.shape - n = n2 * 2 - assert (m % 2) == 0 - lo = a & 0x0F - hi = (a >> 4) & 0x0F - codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) - codes[:, 0::2], codes[:, 1::2] = lo, hi - codes_t = codes.transpose(0, 1).contiguous() - codes2 = codes_t.view(n, m // 2, 2) - out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) - return out.contiguous() \ No newline at end of file